
Using Ollama, RAG, Chroma, and Llama 3.2 - Extracting Stock Target Prices from Analyst Reports

by Playdev 2024. 11. 13.

 

Using Ollama's sample langchain-python-rag-document, I embedded online PDFs, stored them in ChromaDB, and set things up so questions are answered based on the PDFs.

 

As a very simple example, I used the analyst reports for Samsung SDI provided by Naver Finance,

and ChromaDB was switched to port 18000 to match my local environment.
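If you run the Chroma server yourself, the port can be set at launch. A minimal sketch using the `chroma` CLI that ships with the chromadb package (the data directory is an arbitrary example):

```shell
# Start a standalone Chroma server on port 18000
# (./chroma-data is an arbitrary local directory for persistence)
chroma run --host 127.0.0.1 --port 18000 --path ./chroma-data
```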

 

- Uses the llama 3.2 ko bllossom 3b model (https://huggingface.co/Bllossom/llama-3.2-Korean-Bllossom-3B)

- Constrains prompt responses to JSON format

- Embeds the data into ChromaDB

- Answers queries based on the embedded data
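The JSON constraint comes from Ollama's `format` option, which can be tried directly against the local Ollama server before wiring it into LangChain. A hedged sketch, assuming Ollama is listening on its default port 11434 and the model tag from this post is pulled:

```shell
# Ask Ollama for a JSON-only reply ("format": "json" constrains the output)
curl http://127.0.0.1:11434/api/generate -d '{
  "model": "llama-3.2-ko-bllossom-3b",
  "prompt": "삼성SDI 목표 주가는? Reply as JSON, e.g. {\"targetPrice\": 100000}",
  "format": "json",
  "stream": false
}'
```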

 

 

Question: What is Samsung SDI's target price?

Answer: 540000

Execution result

 

from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
import sys
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain_ollama import OllamaEmbeddings, OllamaLLM

import chromadb
import json


class SuppressStdout:
    """Temporarily silence stdout/stderr while noisy libraries initialize."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        sys.stdout = open(os.devnull, 'w')
        sys.stderr = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stderr.close()
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr


# Local Ollama model tag and Chroma collection settings
model = "llama-3.2-ko-bllossom-3b"
collection_name = 'pdf_collection'
embedding_function = OllamaEmbeddings(model=model)

# Connect to the standalone Chroma server (port changed to 18000 locally)
chroma_client = chromadb.HttpClient(host='127.0.0.1', port=18000)
collection = chroma_client.get_or_create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})

research_companies = [
    {'id': 1, 'url': 'https://stock.pstatic.net/stock-research/company/66/20241101_company_811435000.pdf'},
    {'id': 2, 'url': 'https://stock.pstatic.net/stock-research/company/62/20241031_company_338638000.pdf'},
    {'id': 3, 'url': 'https://stock.pstatic.net/stock-research/company/57/20241031_company_741548000.pdf'},
    {'id': 4, 'url': 'https://stock.pstatic.net/stock-research/company/61/20241031_company_230386000.pdf'},
    {'id': 5, 'url': 'https://stock.pstatic.net/stock-research/company/31/20241031_company_977753000.pdf'},
    {'id': 6, 'url': 'https://stock.pstatic.net/stock-research/company/39/20241031_company_852443000.pdf'},
    {'id': 7, 'url': 'https://stock.pstatic.net/stock-research/company/29/20241031_company_948504000.pdf'},
    {'id': 8, 'url': 'https://stock.pstatic.net/stock-research/company/16/20241015_company_754737000.pdf'},
    {'id': 9, 'url': 'https://stock.pstatic.net/stock-research/company/66/20241011_company_916059000.pdf'},
    {'id': 10, 'url': 'https://stock.pstatic.net/stock-research/company/18/20241010_company_668887000.pdf'},
]

from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)

# For each report: load the PDF, split into chunks, embed, and upsert with stable ids
for research in research_companies:
    id = research['id']
    url = research['url']

    loader = PyPDFLoader(url)
    data = loader.load()

    all_splits = text_splitter.split_documents(data)

    ids = [f"{id}_{idx}" for idx in range(len(all_splits))]
    documents = [split.page_content for split in all_splits]
    embedding = embedding_function.embed_documents(documents)

    collection.upsert(ids=ids,
                      documents=documents,
                      embeddings=embedding)

with SuppressStdout():
    vectorstore = Chroma(
        client=chroma_client,
        embedding_function=embedding_function,
        collection_name=collection_name
    )

if __name__ == '__main__':
    query = '삼성SDI 목표 주가는?'

    template = """다음 정보를 사용하여 질문에 답변하세요.
    {context}
    
    질문: {question}
    YOU MUST STRICTLY ADHERE TO THE GIVEN FORMAT.
 
    DO NOT INCLUDE ANY KEYWORDS OTHER THAN THE GIVEN KEYWORD.
    format your response so that it follows a JSON format, for example:
    //
    {{
        "targetPrice" : 100000
    }}

    your response should only contain the JSON format and nothing else.
    """
    QA_CHAIN_PROMPT = PromptTemplate(
        input_variables=["context", "question"],
        template=template,
    )

    # llm = OllamaLLM(model=model, format='json', callbacks=CallbackManager([StreamingStdOutCallbackHandler()]))
    llm = OllamaLLM(model=model, format='json', num_gpu=1, top_k=3)
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    )

    result = qa_chain.invoke({"query": query})

    try:
        result = result['result'].strip()
        target_price = json.loads(result).get("targetPrice")
        print(f'목표가 : {target_price}')
    except (json.JSONDecodeError, KeyError) as e:
        print("JSON 형식으로 변환 실패:", e)
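Because the response is constrained to JSON, the final parsing step can be exercised on its own. A small self-contained sketch (the sample reply strings are made up, and `extract_target_price` is a hypothetical helper, not part of the code above):

```python
import json


def extract_target_price(raw: str):
    """Parse a model reply and return targetPrice, or None on bad input."""
    try:
        data = json.loads(raw.strip())
    except json.JSONDecodeError:
        return None
    # Guard against valid-but-non-object JSON such as a bare number
    return data.get("targetPrice") if isinstance(data, dict) else None


print(extract_target_price('{"targetPrice": 540000}'))  # 540000
print(extract_target_price('not json at all'))          # None
```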

 

 

If the data-embedding process (run as a batch) and the question-answering process are separated,

this could even be developed into a standalone service.
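The split suggested above could start as two subcommands sharing the same Chroma collection. A hypothetical CLI skeleton (command names and flags are made up; the actual embed/ask bodies would reuse the code from this post):

```python
import argparse


def build_parser():
    # "embed" = the batch embedding process, "ask" = the question-answering process
    parser = argparse.ArgumentParser(prog="stock-rag")
    sub = parser.add_subparsers(dest="command", required=True)

    embed = sub.add_parser("embed", help="load report PDFs and upsert them into Chroma")
    embed.add_argument("--port", type=int, default=18000)

    ask = sub.add_parser("ask", help="answer a question against the embedded reports")
    ask.add_argument("question")
    return parser


args = build_parser().parse_args(["ask", "삼성SDI 목표 주가는?"])
print(args.command, args.question)
```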

 
