# Create training data and responses
import json
import os
import copy
import cohere
from elasticsearch.helpers import bulk
from elasticsearch import Elasticsearch
import redis
import anthropic  # Added for Claude API
import logging

from dotenv import load_dotenv
# Load variables from a local .env file (if present) into the process
# environment before any os.getenv() call below reads them.
load_dotenv()

# Anthropic (Claude) credentials and model; the model falls back to the
# default given here when ANTHROPIC_MODEL is not set.
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
ANTHROPIC_MODEL = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-latest")

# Settings for the OpenAI-compatible endpoint. Despite the OPENAI_* names, the
# defaults point at DeepInfra's OpenAI-compatible API and a Qwen model.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "Qwen/Qwen3-Next-80B-A3B-Instruct")
OPENAI_URL = os.getenv("OPENAI_URL", "https://api.deepinfra.com/v1/openai")

from openai import OpenAI

# Create the OpenAI-SDK client against OPENAI_URL using OPENAI_API_KEY.
# NOTE: this module-level name `openai` is the client instance, not the
# `openai` package.
openai = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_URL,
)

# Selects the backend used by get_claude_response: "claude" or "open_llm".
# The environment-driven selection below is currently disabled and the
# OpenAI-compatible backend is hard-wired.
# llm_model = os.getenv("GENAI_ENGINE", "claude")
llm_model = "open_llm"

# Initialize clients
# Elasticsearch client shared by every function in this module.
# (The former module-scope `global es_client` statement was a no-op.)
es_client = Elasticsearch("http://localhost:9200")

# Initialize Redis client
# r = redis.Redis(host='localhost', port=6379, db=0)

# Initialize Cohere client
# SECURITY: the API key used to be hard-coded on this line. It is now read
# from the environment (or .env, loaded by load_dotenv() above) like the
# other provider keys in this file. The previously committed key must be
# treated as leaked and rotated. Set CO_API_KEY before importing this module.
CO_API_KEY = os.getenv("CO_API_KEY")
co = cohere.Client(CO_API_KEY)

# Initialize Claude client
claude_model = ANTHROPIC_MODEL
CLAUDE_API_KEY = ANTHROPIC_API_KEY
claude_client = anthropic.Anthropic(api_key=CLAUDE_API_KEY)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def encode_cohere(chunks, *, input_type="search_document", co_model="embed-v4.0"):
    """Embed a list of texts with Cohere and return their float embeddings.

    Args:
        chunks (list[str]): Texts to embed.
        input_type (str): Cohere ``input_type`` for the request. Defaults to
            ``"search_document"``, the value that was previously hard-coded,
            so existing callers are unaffected. Cohere's API also defines
            ``"search_query"`` for embedding user queries.
        co_model (str): Cohere embedding model name. Defaults to the
            previously hard-coded ``"embed-v4.0"``.

    Returns:
        list: ``response.embeddings.float`` as returned by Cohere, or an
        empty list when ``chunks`` is empty (no API call is made).
    """
    logger.info(f"Inside encode_cohere function with model: {co_model} and {len(chunks)}")

    # Nothing to embed: skip the network round-trip instead of sending an
    # empty request to the API.
    if not chunks:
        return []

    response = co.embed(
        texts=chunks,
        model=co_model,
        input_type=input_type,
        embedding_types=['float'],
    )
    return response.embeddings.float

def _collect_faq_questions(docs):
    """Flatten FAQ entries into parallel question/document lists plus a tag -> answers map."""
    questions = []
    question_docs = []
    tag_responses = {}
    for faq in docs:
        raw_question = faq['question']
        # 'question' may be a single string or a list of alternative phrasings.
        candidates = raw_question if isinstance(raw_question, list) else [raw_question]
        for candidate in candidates:
            if candidate and candidate.strip():  # Only add non-empty questions
                questions.append(candidate.strip())
                question_docs.append({'tagName': faq['tagName']})
        # Recorded for every entry, even one without a usable question.
        # A repeated tagName keeps only the answer of its last entry.
        tag_responses[faq['tagName']] = [faq['answer']]
    return questions, question_docs, tag_responses


def _build_bulk_actions(bid_index, question_docs, vectors, expected_dim=1536):
    """Pair each question document with its vector as a bulk "index" action, logging and skipping invalid pairs."""
    actions = []
    skipped = 0
    for i, (doc, vector) in enumerate(zip(question_docs, vectors)):
        if vector is None:
            logger.warning(f"Skipping document {i}: missing or null vector")
            skipped += 1
            continue

        try:
            dimension = len(vector)
        except TypeError as e:
            logger.warning(f"Error processing document {i}: {e}")
            skipped += 1
            continue

        # Ensure vector is the right dimension (1536 for embed-v4.0)
        if dimension != expected_dim:
            logger.warning(f"Skipping document {i}: invalid vector dimension {dimension}")
            skipped += 1
            continue

        if not doc.get('tagName'):
            logger.warning(f"Skipping document {i}: missing tagName")
            skipped += 1
            continue

        actions.append({
            "_op_type": "index",
            "_index": bid_index,
            "_id": str(i),  # Use string ID (position in the flattened question list)
            "_source": {**doc, "title_vector": vector},
        })

    logger.info(f"Prepared {len(actions)} documents for indexing, {skipped} failed")
    return actions


def es_index_data(bid_id, index_config_file, data_file):
    """Create the Elasticsearch index for a bid and load its FAQ questions.

    The index name is ``bid_id`` lower-cased. Each non-empty question in
    ``data_file`` is embedded with ``encode_cohere`` and indexed as a
    document holding ``tagName`` and ``title_vector``.

    Args:
        bid_id (str): Bid identifier; lower-cased to form the index name.
        index_config_file (str): Path to a JSON file with top-level
            ``"settings"`` and ``"mappings"`` keys used to create the index.
        data_file (str): Path to a JSON file of FAQ entries, each with
            ``question`` (a string or a list of strings), ``tagName`` and
            ``answer``.

    Returns:
        int | dict: The integer ``10`` when the index already exists (nothing
        is created or indexed). Otherwise a dict mapping each ``tagName`` to
        a one-element list with its answer; the dict is empty when the data
        file is missing or contains no entries.

    Raises:
        Exception: Any error from the bulk call other than ``BulkIndexError``
            is logged and re-raised. ``BulkIndexError`` is logged and
            swallowed so that partially indexed data stays searchable.
    """
    # Imported locally because the module-level import only brings in `bulk`.
    # Without this, the `except BulkIndexError` clause below raised NameError
    # as soon as any bulk failure occurred.
    from elasticsearch.helpers import BulkIndexError

    logger.info(f"Creating the Elasticsearch index: {bid_id}")
    bid_index = bid_id.lower()

    if es_client.indices.exists(index=bid_index):
        # Existing index is left untouched; 10 is the value callers receive
        # in that case.
        return 10

    # Context manager so the config file handle is always closed.
    with open(index_config_file, encoding="utf-8") as config_file:
        index_config = json.load(config_file)
    logger.info(f"index file for ES:\n {index_config}")
    es_client.indices.create(index=bid_index, settings=index_config["settings"], mappings=index_config["mappings"])

    if not os.path.isfile(data_file):
        logger.info(f"Data file doesn't exist. File name: {data_file}")
        return {}

    # Same explicit encoding that get_matched_answers uses for this file.
    with open(data_file, encoding="utf-8") as data_json_file:
        docs = json.load(data_json_file)
    if not docs:
        return {}

    questions, question_docs, tag_responses = _collect_faq_questions(docs)
    if not questions:
        logger.info("No valid questions found to process")
        return tag_responses

    logger.info(f"Processing {len(questions)} questions for embedding...")
    title_vectors = encode_cohere(questions)
    if len(title_vectors) != len(questions):
        logger.error(f"ERROR: Vector count ({len(title_vectors)}) doesn't match question count ({len(questions)})")
        return tag_responses

    actions = _build_bulk_actions(bid_index, question_docs, title_vectors)
    if not actions:
        logger.info("No valid documents to index")
        return tag_responses

    try:
        # Use bulk with error handling
        success_count, failed_items = bulk(
            es_client,
            actions,
            index=bid_index,
            request_timeout=60,
            max_retries=3,
            initial_backoff=2,
            max_backoff=600,
        )

        logger.info(f"Successfully indexed {success_count} documents")
        if failed_items:
            logger.warning(f"Failed to index {len(failed_items)} documents")
            for item in failed_items[:5]:  # Show first 5 failures
                logger.warning(f"Failed item: {item}")

        es_client.indices.refresh(index=bid_index)
        logger.info(f"Done indexing for {bid_index}")

    except BulkIndexError as e:
        logger.error(f"Bulk indexing error: {e}")
        logger.error("Error details:")
        for error in e.errors[:5]:  # Show first 5 errors
            logger.error(f"  - {error}")

        # Try to get successfully indexed count (best effort only).
        try:
            indexed_count = es_client.count(index=bid_index)['count']
            logger.info(f"Despite errors, {indexed_count} documents were successfully indexed")
        except Exception:
            logger.info("Could not determine how many documents were indexed")

        # Don't raise the error, continue with partial success
        es_client.indices.refresh(index=bid_index)

    except Exception as e:
        logger.error(f"Unexpected error during bulk indexing: {e}")
        raise

    return tag_responses

def get_matched_answers(es_response, json_filepath):
    """Map Elasticsearch hits back to their stored answers.

    Reads the FAQ JSON at ``json_filepath`` (a list of objects with
    ``tagName`` and ``answer``) and returns, in hit order, the answer for
    every hit in ``es_response['hits']['hits']`` whose
    ``_source.tagName`` appears in that file. Hits with an unknown tag are
    dropped. When a tag occurs more than once in the file, the last answer
    is used.
    """
    with open(json_filepath, 'r', encoding='utf-8') as handle:
        faq_entries = json.load(handle)

    # Build the tag -> answer lookup once so each hit is a dict access.
    answers_by_tag = {}
    for entry in faq_entries:
        tag = entry['tagName']
        answers_by_tag[tag] = entry['answer']

    matched = []
    for hit in es_response['hits']['hits']:
        hit_tag = hit['_source']['tagName']
        if hit_tag in answers_by_tag:
            matched.append(answers_by_tag[hit_tag])
    return matched

def run_es_query(query, query_embeddings, client_id, bid_id, data_file):
    """Retrieve FAQ answers for a query via vector search plus reranking.

    Scores every document in the bid's index by cosine similarity between
    ``title_vector`` and the query vector (shifted by +1.0), takes the top
    15 hits, maps them to answers with ``get_matched_answers``, and reranks
    those answers with Cohere, keeping at most 8.

    Args:
        query (str): The user's question (used for reranking).
        query_embeddings (list): Embeddings for the query; only the first
            element is used as the query vector.
        client_id (str): Client identifier, used for logging only.
        bid_id (str): Bid identifier; lower-cased to form the index name.
        data_file (str): Path to the FAQ JSON file holding the answers.

    Returns:
        list: Answers ordered by rerank relevance. Empty when the index does
        not exist, the search returns no hits, or no hit maps to an answer.
    """
    logger.info(f"Entered ES Query for client: {client_id}")

    # es_index_data creates the index under the lower-cased bid id, so the
    # same normalisation is applied here; querying the original-case name
    # would miss the index for any bid id containing upper-case letters.
    index_name = bid_id.lower()
    if not es_client.indices.exists(index=index_name):
        # Searching a missing index would only raise; report "no context".
        logger.warning(f"⚠️ {bid_id} index does not exist in ES")
        return []
    logger.info(f"✅ {bid_id} index exists in ES")

    query_vector = query_embeddings[0]

    script_query = {
        "script_score": {
            "query": {"match_all": {}},
            "script": {
                "source": "cosineSimilarity(params.query_vector, doc['title_vector']) + 1.0",
                "params": {"query_vector": query_vector}
            }
        }
    }

    es_response = es_client.search(
        index=index_name,
        body={
            "size": 15,
            "query": script_query,
            "_source": {"includes": ["tagName"]}
        }
    )

    final_context = []

    logger.info(f"Data File path is {data_file}")
    # Check if any positive response has come
    if es_response["hits"]["hits"]:
        logger.info("✅ Positive responses returned in ES search")
        answer_doc_list = get_matched_answers(es_response, data_file)
        logger.info(f"Answer doc list shows {answer_doc_list}")
        # Only rerank when at least one hit mapped to an answer; there is
        # nothing to rank otherwise.
        if answer_doc_list:
            reranked_response = co.rerank(
                query=query,
                documents=answer_doc_list,
                top_n=8,
                model="rerank-v3.5",
            )
            logger.info(f"Reranked response list shows {reranked_response}")
            # Each result's index refers back into answer_doc_list.
            final_context = [answer_doc_list[res.index] for res in reranked_response.results]

    logger.info(f"Final FAQ list is: {final_context}")
    # Redis caching of the context is currently disabled:
    # r.jsondel('yubo:conversation_history_' + client_id)
    # r.jsonset('yubo:context_'+client_id, '$', final_context)
    # r.expire('yubo:context_'+client_id, 300)

    return final_context
    
def get_context(text, bid_id, client_id, data_file):
    """
    Get context for a given query using Elasticsearch.

    Args:
        text (str): The user's query.
        bid_id (str): Bid identifier; lower-cased to form the index name.
        client_id (str): Client identifier (passed through for logging).
        data_file (str): Path to the FAQ JSON file holding the answers.

    Returns:
        list: Context answers from ``run_es_query``, or an empty list when
        the bid's index does not exist.
    """
    logger.info(f"Entered get_intent for client id: {client_id}, With user input: {text}")

    # The index is created under the lower-cased bid id, so the same name
    # must be used for both the existence check and the query. Previously
    # the check used the lower-cased name but the query used the original.
    index_name = bid_id.lower()

    # Check if the index exists before querying
    if not es_client.indices.exists(index=index_name):
        return []

    # NOTE(review): encode_cohere embeds the query with
    # input_type="search_document"; Cohere also defines "search_query" for
    # queries. Consider whether the query should use that instead.
    text_embeddings = encode_cohere([text])
    return run_es_query(text, text_embeddings, client_id, index_name, data_file)

def get_claude_response(query, context_list, client_id, no_es_query):
    """
    Get a response from the configured LLM using the context and user query.

    The backend is chosen by the module-level ``llm_model``: ``"claude"``
    uses the Anthropic client, ``"open_llm"`` uses the OpenAI-compatible
    client.

    Args:
        query (str): The user's question
        context_list (list | str): Context for the prompt. A list of context
            strings when ``no_es_query`` is falsy; used as-is when
            ``no_es_query`` is truthy.
        client_id (str): The client identifier (used for logging only)
        no_es_query (bool): When truthy, ``context_list`` is inserted into
            the prompt unchanged instead of being joined with blank lines.

    Returns:
        str: The LLM's response to the user query based on the context

    Raises:
        ValueError: If ``llm_model`` is neither ``"claude"`` nor
            ``"open_llm"``. Previously this case silently returned ``None``.
    """
    logger.info(f"Getting Minaions response for client: {client_id}")

    # Fail fast on a misconfigured backend instead of returning None.
    if llm_model not in ("claude", "open_llm"):
        raise ValueError(f"Unsupported llm_model: {llm_model!r}")

    if no_es_query:
        formatted_context = context_list
    else:
        # Format the context as a string
        formatted_context = "\n\n".join(context_list)

    system_prompt = (
        "You are an AI assistant helping with RFP (Request for Proposal) questions. "
        "Respond only with information found in the provided context."
    )

    # Create the prompt. It is written with explicit "\n" and spaces so the
    # whitespace-only line and the trailing space of the original prompt
    # text stay visible and are not lost to editor whitespace trimming.
    prompt = (
        "You are an AI assistant helping with RFP (Request for Proposal) questions.\n"
        "    \n"
        "CONTEXT:\n"
        f"{formatted_context}\n"
        "\n"
        "USER QUERY:\n"
        f"{query}\n"
        "\n"
        "Provide a helpful, accurate, and concise response to the user query based only on "
        "the information provided in the context above. If the context doesn't contain "
        "relevant information to answer the query, say "
        "\"I don't have enough information to answer that question.\" \n"
        "DO NOT make up information not present in the context.\n"
        "\n"
        "Return ONLY the response to the USER QUERY with no additional commentary.\n"
    )

    if llm_model == "claude":
        # Get response from Claude
        response = claude_client.messages.create(
            model=claude_model,
            max_tokens=2000,
            temperature=0.2,
            system=system_prompt,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        logger.info("Minaions response generated successfully")
        return response.content[0].text

    # llm_model == "open_llm": OpenAI-compatible chat completion.
    response = openai.chat.completions.create(
        model=OPENAI_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )
    logger.info("Minaions response generated successfully")
    return response.choices[0].message.content
    

def chat_with_rfp(query, bid_id, client_id, data_file, no_es_query, merged_txt_file):
    """
    Entry point for the chat flow: choose the context, then ask the LLM.

    Args:
        query (str): The user's question
        bid_id (str): The bid identifier for the Elasticsearch index
        client_id (str): The client identifier
        data_file (str): FAQ JSON path, used when context comes from Elasticsearch
        no_es_query (bool): When truthy, Elasticsearch is skipped and
            ``merged_txt_file`` is passed to the LLM as the context
        merged_txt_file: Context handed to the LLM unchanged when
            ``no_es_query`` is truthy (presumably the merged RFP text;
            confirm with the caller)

    Returns:
        str: The LLM's response to the user query
    """
    # The conditional expression is lazy, so Elasticsearch is only queried
    # when the pre-merged text is not being used.
    context = merged_txt_file if no_es_query else get_context(query, bid_id, client_id, data_file)
    return get_claude_response(query, context, client_id, no_es_query)
