Merge pull request #72 from WNJXYK/main

[ENH] Add knowledge base for LawGPT
This commit is contained in:
喂你脚下有坑
2023-05-28 13:37:42 +08:00
committed by GitHub
2 changed files with 54 additions and 2 deletions

46
utils/knowledge.py Normal file
View File

@ -0,0 +1,46 @@
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import sentence_transformers
import numpy as np
import re, os
__all__ = ["Knowledge"]
class Knowledge(object):
    """Look up supporting documents in a FAISS vector store saved on disk.

    The store is loaded once at construction time. ``query_prompt`` embeds a
    text query, searches the index and returns the matching documents as plain
    dicts; ``get_response`` renders those dicts into a single display string.
    """

    def __init__(self, knowledge_path="./knowledge", embedding_name='GanymedeNil/text2vec-large-chinese') -> None:
        """Load the embedding model and the saved FAISS index.

        Args:
            knowledge_path: Directory handed to ``FAISS.load_local``.
            embedding_name: Model name handed to ``HuggingFaceEmbeddings``.
        """
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_name)
        self.knowledge = FAISS.load_local(knowledge_path, embeddings=self.embeddings)
        # NOTE: an earlier revision carried a commented-out override that
        # pointed the embedding client at a developer-local SentenceTransformer
        # checkpoint on CUDA; it was machine-specific and has been dropped.

    def render_index(self, idx, score) -> dict:
        """Turn one search hit into a plain dict.

        Args:
            idx: Position of the hit in the FAISS index.
            score: Score reported by the index for this hit.

        Returns:
            A dict with ``title`` (the document's ``source`` metadata),
            ``score`` (truncated to ``int``) and ``content`` (the document's
            ``content`` metadata).
        """
        doc_id = self.knowledge.index_to_docstore_id[idx]
        doc = self.knowledge.docstore.search(doc_id)
        meta = doc.metadata
        return {"title": meta["source"], "score": int(score), "content": meta["content"]}

    def query_prompt(self, prompt, topk=3, threshold=700) -> list:
        """Return up to ``topk`` documents matching ``prompt``.

        Args:
            prompt: Query text to embed and search for.
            topk: Number of neighbours requested from the index.
            threshold: Hits whose score is greater than this value are
                skipped (presumably the score is a distance, so lower means
                closer -- confirm against the index type in use).

        Returns:
            A list of dicts as produced by ``render_index``, in the order the
            index returned them, keeping only the first hit for each title.
        """
        embedding = self.knowledge.embedding_function(prompt)
        query = np.array([embedding], dtype=np.float32)
        scores, indices = self.knowledge.index.search(query, topk)

        docs = []
        seen_titles = set()
        # Only one query vector was submitted, so row 0 holds all results.
        for score, idx in zip(scores[0], indices[0]):
            # An index of -1 marks an empty result slot (fewer than topk hits).
            if idx == -1:
                continue
            if score > threshold:
                continue
            item = self.render_index(idx, score)
            if item["title"] in seen_titles:
                continue
            seen_titles.add(item["title"])
            docs.append(item)
        return docs

    def get_response(self, output: list) -> str:
        """Join the ``content`` of each document into one string.

        Args:
            output: Documents as returned by ``query_prompt``.

        Returns:
            The ``content`` values concatenated with ``"---\\n"`` between
            consecutive entries; an empty string when ``output`` is empty.
        """
        return "---\n".join(doc["content"] for doc in output)
# Example usage (expects a saved FAISS index under ./knowledge):
# knowledge = Knowledge()
# answer = knowledge.query_prompt("强奸男性犯法吗?")
# print(answer)
# print(knowledge.get_response(answer))

View File

@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
from utils.callbacks import Iteratorize, Stream from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter from utils.prompter import Prompter
from utils.knowledge import Knowledge
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda" device = "cuda"
@ -37,6 +38,7 @@ def main(
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'" ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template) prompter = Prompter(prompt_template)
knowledge = Knowledge()
tokenizer = LlamaTokenizer.from_pretrained(base_model) tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda": if device == "cuda":
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
@ -106,6 +108,7 @@ def main(
): ):
input=None input=None
prompt = prompter.generate_prompt(instruction, input) prompt = prompter.generate_prompt(instruction, input)
legals = knowledge.query_prompt(instruction)
inputs = tokenizer(prompt, return_tensors="pt") inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device) input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig( generation_config = GenerationConfig(
@ -152,7 +155,7 @@ def main(
if output[-1] in [tokenizer.eos_token_id]: if output[-1] in [tokenizer.eos_token_id]:
break break
yield prompter.get_response(decoded_output) yield prompter.get_response(decoded_output), knowledge.get_response(legals)
print(decoded_output) print(decoded_output)
return # early return for stream_output return # early return for stream_output
@ -168,7 +171,7 @@ def main(
s = generation_output.sequences[0] s = generation_output.sequences[0]
output = tokenizer.decode(s) output = tokenizer.decode(s)
print(output) print(output)
yield prompter.get_response(output) yield prompter.get_response(output), knowledge.get_response(legals)
gr.Interface( gr.Interface(
fn=evaluate, fn=evaluate,
@ -200,6 +203,9 @@ def main(
gr.inputs.Textbox( gr.inputs.Textbox(
lines=8, lines=8,
label="Output", label="Output",
),
gr.inputs.Textbox(
label="Legal Ground",
) )
], ],
title="🦙🌲 LaWGPT", title="🦙🌲 LaWGPT",