
Comments (6)

benman1 commented on July 3, 2024

Hey @mrchaos! Thanks for finding that, that's cool! Do you feel like creating another PR?


CSalle commented on July 3, 2024

Hi,

I've been fighting with this issue for a while (and with several more that appeared as I fixed the initial ones) and finally managed to code a fix. It is implemented with a later version of LangChain (I upgraded it during the course to avoid issues with OpenAI's deprecated models), and I also made some changes to avoid deprecation warnings. I leave my changes below in case they help anyone.

The only issue I faced is with the moderation chain, which seems to have been deprecated by OpenAI; there is no straightforward way to upgrade it in a Windows environment. The rest should run just fine. In terms of performance, the FLARE chain works much worse for me, as it tends to hallucinate more and be more "confused" about the documents in its context.
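
In case it helps anyone as a stopgap, here is a minimal sketch of calling OpenAI's moderation endpoint directly with the openai>=1.0 client instead of going through OpenAIModerationChain (a standalone check, not wired into the chain; the helper name is just for illustration):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def is_flagged(text: str) -> bool:
    """Return True if the moderation endpoint flags the text."""
    result = client.moderations.create(input=text)
    return result.results[0].flagged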

I can open a PR if you like, but I have not tested this with earlier versions of LangChain. My environment:

  • langchain==0.1.10
  • langchain-community==0.0.25
  • langchain-core==0.1.28
  • langchain-decorators==0.5.4
  • langchain-experimental==0.0.53
  • langchain-openai==0.0.8
  • langchain-text-splitters==0.0.1

Main changes:

In utils.py:

  • Upgrade library imports to avoid deprecation warnings
  • Add a .lower() on the document's file extension, so we do not get errors with files named in caps (e.g. "FILE.PDF")
"""Utility functions and constants.

I am having some problems caching the memory and the retrieval. When
I decorate for caching, I get streamlit init errors.
"""
import logging
import pathlib
from typing import Any

from langchain.memory import ConversationBufferMemory
from langchain.schema import Document
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredWordDocumentLoader,
)


def init_memory():
    """Initialize the memory for contextual conversation.

    We are caching this, so it won't be deleted
     every time, we restart the server.
     """
    return ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
        output_key='answer'
    )

MEMORY = init_memory()


class EpubReader(UnstructuredEPubLoader):
    def __init__(self, file_path: str | list[str], **unstructured_kwargs: Any):
        super().__init__(file_path, **unstructured_kwargs, mode="elements", strategy="fast")


class DocumentLoaderException(Exception):
    pass


class DocumentLoader(object):
    """Loads in a document with a supported extension."""
    supported_extensions = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".epub": EpubReader,
        ".docx": UnstructuredWordDocumentLoader,
        ".doc": UnstructuredWordDocumentLoader,
    }


def load_document(temp_filepath: str) -> list[Document]:
    """Load a file and return it as a list of documents.

    Doesn't handle a lot of errors at the moment.
    """
    ext = pathlib.Path(temp_filepath).suffix.lower()
    loader = DocumentLoader.supported_extensions.get(ext)
    if not loader:
        raise DocumentLoaderException(
            f"Invalid extension type {ext}, cannot load this type of file"
        )

    loaded = loader(temp_filepath)
    docs = loaded.load()
    logging.info(docs)
    return docs
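
For reference, a quick usage sketch of the case-insensitive extension handling (the file name is hypothetical):

# "REPORT.PDF" now resolves to PyPDFLoader via the lower-cased ".pdf" key
docs = load_document("REPORT.PDF")
print(len(docs), docs[0].metadata)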

In chat_with_documents.py:

  • Upgrade library imports to avoid deprecation warnings
  • Replace SimpleSequentialChain with SequentialChain, due to errors about the number of inputs required by the former
  • Parameterize the output_key and input variables so that they are aligned with the chain in use (FLARE vs. normal)
"""Chat with retrieval and embeddings."""
import logging
import os
import tempfile

from langchain.chains import (
    ConversationalRetrievalChain,
    FlareChain,
    OpenAIModerationChain,
    SequentialChain,
)
from langchain.chains.base import Chain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

from chat_with_retrieval.utils import MEMORY, load_document
from config import set_environment

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()
set_environment()

# Setup LLM and QA chain; set temperature low to keep hallucinations in check
LLM = ChatOpenAI(
    model_name="gpt-3.5-turbo", temperature=0, streaming=True
)


def configure_retriever(
        docs: list[Document],
        use_compression: bool = False
) -> BaseRetriever:
    """Retriever to use."""
    # Split each document into chunks:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create embeddings and store in vectordb:
    embeddings = OpenAIEmbeddings()
    # alternatively: HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # Create vectordb with single call to embedding model for texts:
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
    retriever = vectordb.as_retriever(
        search_type="mmr", search_kwargs={
            "k": 5,
            "fetch_k": 7,
            "include_metadata": True
        },
    )
    if not use_compression:
        return retriever

    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=0.2
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=retriever,
    )


def configure_chain(retriever: BaseRetriever, use_flare: bool = True) -> Chain:
    """Configure chain with a retriever.

    Passing in a max_tokens_limit amount automatically
    truncates the tokens when prompting your llm!
    """
    output_key = 'response' if use_flare else 'answer'
    MEMORY.output_key = output_key
    params = dict(
        llm=LLM,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
        max_tokens_limit=4000,
    )
    if use_flare:
        # FLARE needs a different set of parameters and initialization
        return FlareChain.from_llm(
            **params
        )
    return ConversationalRetrievalChain.from_llm(
        **params
    )


def configure_retrieval_chain(
        uploaded_files,
        use_compression: bool = False,
        use_flare: bool = False,
        use_moderation: bool = False
) -> Chain:
    """Read documents, configure retriever, and the chain."""
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        docs.extend(load_document(temp_filepath))

    retriever = configure_retriever(docs=docs, use_compression=use_compression)
    chain = configure_chain(retriever=retriever, use_flare=use_flare)
    if not use_moderation:
        return chain

    input_variables = ["user_input"] if use_flare else ["chat_history", "question"]
    moderation_input = "response" if use_flare else "answer"
    moderation_chain = OpenAIModerationChain(input_key=moderation_input)
    return SequentialChain(chains=[chain, moderation_chain], 
                           input_variables=input_variables)
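
For anyone testing the chain outside Streamlit, a rough sketch of how the pieces above fit together (the file path and question are placeholders; assumes set_environment() has provided a valid OPENAI_API_KEY):

docs = load_document("example.pdf")
retriever = configure_retriever(docs=docs, use_compression=False)
chain = configure_chain(retriever=retriever, use_flare=False)
result = chain.invoke({
    "question": "What is this document about?",
    "chat_history": MEMORY.chat_memory.messages,
})
print(result["answer"])  # ConversationalRetrievalChain returns its reply under "answer"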

 

In app.py:

  • Parameterize the params provided to the chain depending on use_flare
  • Change .run() to .invoke() (with the corresponding inputs) to avoid deprecation warnings
  • Extract the actual answer from the chain's response (.invoke() returns an object with question, context and answer)
"""Document loading functionality.

Run like this:
> PYTHONPATH=. streamlit run chat_with_retrieval/app.py
"""
import logging

import streamlit as st
from streamlit.external.langchain import StreamlitCallbackHandler

from chat_with_retrieval.chat_with_documents import configure_retrieval_chain
from chat_with_retrieval.utils import MEMORY, DocumentLoader

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()

st.set_page_config(page_title="LangChain: Chat with Documents", page_icon="🦜")
st.title("🦜 LangChain: Chat with Documents")


uploaded_files = st.sidebar.file_uploader(
    label="Upload files",
    type=list(DocumentLoader.supported_extensions.keys()),
    accept_multiple_files=True
)
if not uploaded_files:
    st.info("Please upload documents to continue.")
    st.stop()

# chain options (all off by default):
use_compression = st.checkbox("compression", value=False)
use_flare = st.checkbox("flare", value=False)
use_moderation = st.checkbox("moderation", value=False)

CONV_CHAIN = configure_retrieval_chain(
    uploaded_files,
    use_compression=use_compression,
    use_flare=use_flare,
    use_moderation=use_moderation
)

if st.sidebar.button("Clear message history"):
    MEMORY.chat_memory.clear()

avatars = {"human": "user", "ai": "assistant"}

if len(MEMORY.chat_memory.messages) == 0:
    st.chat_message("assistant").markdown("Ask me anything!")

for msg in MEMORY.chat_memory.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)

assistant = st.chat_message("assistant")
if user_query := st.chat_input(placeholder="Give me 3 keywords for what you have right now"):
    st.chat_message("user").write(user_query)
    container = st.empty()
    stream_handler = StreamlitCallbackHandler(container)
    with st.chat_message("assistant"):
        params = {
            "question": user_query,
            "chat_history": MEMORY.chat_memory.messages
        }
        if use_flare:
            params = {"user_input": user_query}
        config = {'callbacks': [stream_handler]}
        response = CONV_CHAIN.invoke(input=params, config=config)
        output_key = 'response' if use_flare else 'answer'

        # Display the response from the chatbot
        if response:
            container.markdown(response[output_key])


benman1 commented on July 3, 2024

@CSalle that's really cool! (Sorry for only getting back to you now.) I'll test this on langchain 0.0.284 soon.

I've started another branch called softupdate that runs on a newer version of LangChain (0.1.13). I've only updated a few notebooks so far (chapter 3), but if you like, you could create a PR against that branch.


CSalle commented on July 3, 2024

Done!


benman1 commented on July 3, 2024

Thanks, @CSalle and @mrchaos! Closing this now.


benman1 commented on July 3, 2024

@mrchaos and @CSalle, it took me a while, but I've finally pushed these changes to main, starting with 2b886bf. It's also on the softupdate branch. Thanks, both of you; that was really helpful! Please let me know if anything else needs to change.

