# Index RFP FAQ data into Elasticsearch and answer questions using retrieved context and Claude
import json
import os
import copy
import cohere
from elasticsearch.helpers import bulk
from elasticsearch import Elasticsearch
import redis
import anthropic  # Added for Claude API

# Initialize clients.
# Connection targets can be overridden through the environment; the defaults
# are the local development endpoints this module has always used.
es_client = Elasticsearch(os.environ.get("ES_URL", "http://localhost:9200"))

# Initialize Redis client
r = redis.Redis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", "6379")),
    db=0,
)

# SECURITY: API keys must never be committed to source control. They are read
# from the environment instead. Any keys that were previously hard-coded in
# this file have been exposed and must be revoked/rotated.

# Initialize Cohere client
CO_API_KEY = os.environ.get("CO_API_KEY")
co = cohere.Client(CO_API_KEY)

# Initialize Claude client (CLAUDE_API_KEY, falling back to the SDK's
# conventional ANTHROPIC_API_KEY variable).
CLAUDE_API_KEY = os.environ.get("CLAUDE_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
claude_client = anthropic.Anthropic(api_key=CLAUDE_API_KEY)

def encode_cohere(chunks, batch_size=96, input_type="search_document"):
    """Embed ``chunks`` with Cohere and return one float vector per chunk.

    Cohere's embed endpoint accepts a limited number of texts per request
    (96), so the input is sent in consecutive batches of ``batch_size`` and
    the resulting vectors are concatenated in input order.

    Args:
        chunks: Sequence of strings to embed.
        batch_size: Maximum number of texts sent per embed request.
        input_type: Cohere ``input_type`` for the request; defaults to
            ``"search_document"`` as before.

    Returns:
        list: Float embedding vectors aligned index-for-index with ``chunks``.
        An empty input returns an empty list without calling the API.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    co_model = "embed-v4.0"
    print('Inside encode_cohere function with model:', co_model, len(chunks))

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    embeddings = []
    for start in range(0, len(chunks), batch_size):
        response = co.embed(
            texts=list(chunks[start:start + batch_size]),
            model=co_model,
            input_type=input_type,
            embedding_types=['float'],
        )
        embeddings.extend(response.embeddings.float)
    return embeddings

def es_index_data(bid_id, index_config_file, data_file):
    """(Re)create the Elasticsearch index for a bid and load its FAQ data.

    The index is named ``bid_id.lower()``. Any existing index with that name
    is dropped, then recreated from the ``settings``/``mappings`` in
    ``index_config_file``. If ``data_file`` exists, every question in it is
    embedded with Cohere and indexed as a ``{tagName, title_vector}`` document.

    Args:
        bid_id (str): Bid identifier; lower-cased to form the index name.
        index_config_file (str): Path to a JSON file with ``settings`` and
            ``mappings`` keys.
        data_file (str): Path to a JSON list of FAQ records with ``tagName``,
            ``question`` (a string or a list of strings) and ``answer`` keys.

    Returns:
        dict: ``{tagName: [answer, ...]}`` for every record in ``data_file``.
        Records sharing a tagName contribute all of their answers in file
        order. Empty when the file is missing or contains no records.
    """
    print("Creating the Elasticsearch index:", bid_id)
    bid_index = bid_id.lower()

    # Always start from a clean index; a concurrently-deleted index is not an error.
    if es_client.indices.exists(index=bid_index):
        es_client.indices.delete(index=bid_index, ignore=[404])

    with open(index_config_file, encoding='utf-8') as config_file:
        index_config = json.load(config_file)
    print("index file for ES:\n", index_config)
    es_client.indices.create(index=bid_index, settings=index_config["settings"], mappings=index_config["mappings"])

    if not os.path.isfile(data_file):
        print("Data file doesn't exist. File name:", data_file)
        return {}

    with open(data_file, encoding='utf-8') as data_json_file:
        docs = json.load(data_json_file)

    questions, new_docs, tag_responses = _flatten_faqs(docs)
    if not questions:
        # Nothing to embed; calling the embed API with no texts would fail.
        return tag_responses

    title_vectors = encode_cohere(questions)
    if len(title_vectors) != len(new_docs):
        raise RuntimeError(
            f"Expected {len(new_docs)} embeddings, got {len(title_vectors)}"
        )

    requests = [
        {**doc, "_op_type": "index", "_index": bid_index, "title_vector": vector}
        for doc, vector in zip(new_docs, title_vectors)
    ]
    bulk(es_client, requests)
    es_client.indices.refresh(index=bid_index)
    print("Done indexing for", bid_index)

    return tag_responses


def _flatten_faqs(docs):
    """Expand FAQ records into one indexable entry per question.

    Returns ``(questions, new_docs, tag_responses)`` where ``new_docs[i]`` is
    the ``{'tagName': ...}`` document for ``questions[i]`` and
    ``tag_responses`` maps each tagName to its answers in record order.
    """
    questions = []
    new_docs = []
    tag_responses = {}
    for faq in docs:
        tag = faq['tagName']
        faq_questions = faq['question']
        if not isinstance(faq_questions, list):
            faq_questions = [faq_questions]
        for question in faq_questions:
            questions.append(question)
            new_docs.append({'tagName': tag})
        tag_responses.setdefault(tag, []).append(faq['answer'])
    return questions, new_docs, tag_responses

def get_matched_answers(es_response, json_filepath):
    """Map Elasticsearch hits back to the FAQ answers stored on disk.

    Args:
        es_response: Search response whose ``hits.hits[*]._source`` entries
            carry a ``tagName``.
        json_filepath: Path to a UTF-8 JSON list of records with ``tagName``
            and ``answer`` keys.

    Returns:
        list: The answer for each hit, in hit order. Hits whose tagName has
        no record in the file are skipped. If a tagName occurs more than once
        in the file, the answer from its last record is used.
    """
    with open(json_filepath, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    # Build the tag -> answer lookup once so each hit is an O(1) check.
    answer_by_tag = {}
    for record in records:
        answer_by_tag[record['tagName']] = record['answer']

    answers = []
    for hit in es_response['hits']['hits']:
        tag = hit['_source']['tagName']
        if tag in answer_by_tag:
            answers.append(answer_by_tag[tag])
    return answers

def run_es_query(query, query_embeddings, client_id, bid_id):
    """Retrieve, rerank and cache FAQ answers for a query.

    Scores every document in the bid's index by cosine similarity between
    ``title_vector`` and the query vector, keeping the top 15 hits. Each hit
    is mapped to its answer via ``../<client_id>/chunks.json``. The distinct
    answers are reranked with Cohere, keeping at most 8. The result is stored
    in Redis under ``yubo:context_<client_id>`` with a 300-second TTL. The
    client's ``yubo:conversation_history_<client_id>`` key is deleted.

    Args:
        query (str): The user's question (used for reranking).
        query_embeddings (list): Embeddings for the query; only the first is used.
        client_id (str): Client identifier used for file paths and Redis keys.
        bid_id (str): Bid identifier; searched under its lower-cased name.

    Returns:
        list: Reranked answers, best first; empty when nothing matched.
    """
    print("Entered ES Query for client:", client_id)
    query_vector = query_embeddings[0]

    script_query = {
        "script_score": {
            "query": {"match_all": {}},
            "script": {
                "source": "cosineSimilarity(params.query_vector, doc['title_vector']) + 1.0",
                "params": {"query_vector": query_vector}
            }
        }
    }

    # Indices are created and existence-checked under the lower-cased bid id
    # in this module, so search under the same normalized name.
    es_response = es_client.search(
        index=bid_id.lower(),
        body={
            "size": 15,
            "query": script_query,
            "_source": {"includes": ["tagName"]}
        }
    )

    final_context = []

    if es_response["hits"]["hits"]:
        # NOTE(review): client_id is concatenated into a filesystem path; make
        # sure callers never pass untrusted values (path traversal).
        matched_answers = get_matched_answers(es_response, '../' + client_id + '/chunks.json')

        # Several indexed questions can share one tagName, so the same answer
        # may come back for multiple hits; rerank each distinct answer once.
        answer_doc_list = []
        for answer in matched_answers:
            if answer not in answer_doc_list:
                answer_doc_list.append(answer)

        # Rerank rejects an empty document list, so only call it with candidates.
        if answer_doc_list:
            reranked_response = co.rerank(
                query=query,
                documents=answer_doc_list,
                top_n=min(8, len(answer_doc_list)),
                model="rerank-v3.5",
            )
            final_context = [answer_doc_list[res.index] for res in reranked_response.results]

    print("\nFinal FAQ list is:", final_context)
    # redis-py exposes RedisJSON commands through ``Redis.json()``; the plain
    # client has no ``jsondel``/``jsonset`` methods.
    r.json().delete('yubo:conversation_history_' + client_id)
    context_key = 'yubo:context_' + client_id
    r.json().set(context_key, '$', final_context)
    r.expire(context_key, 300)

    return final_context
    
def get_context(text, bid_id, client_id):
    """
    Get context for a given query using Elasticsearch.

    Args:
        text (str): The user's query.
        bid_id (str): Bid identifier; the index is looked up by its
            lower-cased name.
        client_id (str): The client identifier.

    Returns:
        list: Reranked FAQ answers from ``run_es_query``, or an empty list
        when the bid's index does not exist.
    """
    print("Entered get_intent for client id:", client_id, "With user input:", text)

    # Indices are stored under the lower-cased bid id. Use the same name for
    # both the existence check and the search so the two cannot disagree.
    bid_index = bid_id.lower()
    if not es_client.indices.exists(index=bid_index):
        return []

    # NOTE(review): Cohere recommends input_type="search_query" for query
    # embeddings; this call uses encode_cohere's "search_document" default.
    text_embeddings = encode_cohere([text])
    return run_es_query(text, text_embeddings, client_id, bid_index)

def get_claude_response(query, context_list, client_id, *, model="claude-3-7-sonnet-latest",
                        max_tokens=2000, temperature=0.2):
    """
    Get response from Claude LLM using the context and user query.

    Args:
        query (str): The user's question
        context_list (list): Context documents retrieved from Elasticsearch;
            each item is converted with ``str()`` and joined with blank lines.
        client_id (str): The client identifier (used for logging only).
        model (str): Claude model name to call.
        max_tokens (int): Maximum number of tokens Claude may generate.
        temperature (float): Sampling temperature.

    Returns:
        str: Claude's response to the user query based on the context. This
        is the concatenation of all text blocks in the reply; it is empty if
        the reply contains no text.
    """
    print("Getting Claude response for client:", client_id)

    # Format the context as a string
    formatted_context = "\n\n".join(str(item) for item in context_list)

    # Create the prompt for Claude
    prompt = f"""You are an AI assistant helping with RFP (Request for Proposal) questions.
    
CONTEXT:
{formatted_context}

USER QUERY:
{query}

Provide a helpful, accurate, and concise response to the user query based only on the information provided in the context above. If the context doesn't contain relevant information to answer the query, say "I don't have enough information to answer that question." 
DO NOT make up information not present in the context.

Return ONLY the response to the USER QUERY with no additional commentary.
"""

    # Get response from Claude
    response = claude_client.messages.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system="You are an AI assistant helping with RFP (Request for Proposal) questions. Respond only with information found in the provided context.",
        messages=[
            {"role": "user", "content": prompt}
        ]
    )

    print("Claude response generated successfully")
    # ``content`` is a list of content blocks. Join the text blocks instead of
    # assuming a first block exists and is text.
    return "".join(
        block.text for block in response.content if getattr(block, "type", None) == "text"
    )

def chat_with_rfp(query, bid_id, client_id):
    """
    Answer an RFP question end to end: retrieve context, then ask Claude.

    Args:
        query (str): The user's question.
        bid_id (str): Bid identifier naming the Elasticsearch index to search.
        client_id (str): The client identifier.

    Returns:
        str: Claude's answer, generated from the retrieved context.
    """
    # Retrieval runs first (argument evaluation), then generation.
    return get_claude_response(query, get_context(query, bid_id, client_id), client_id)
