diff --git a/utils/knowledge.py b/utils/knowledge.py
new file mode 100644
index 0000000..f76f612
--- /dev/null
+++ b/utils/knowledge.py
@@ -0,0 +1,44 @@
+from langchain.vectorstores.faiss import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
+import numpy as np
+
+__all__ = ["Knowledge"]
+
+class Knowledge(object):
+    def __init__(self, knowledge_path="./knowledge", embedding_name='GanymedeNil/text2vec-large-chinese') -> None:
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_name)
+        self.knowledge = FAISS.load_local(knowledge_path, embeddings=self.embeddings)
+
+    def render_index(self, idx, score):
+        # Map a raw FAISS row index back to its docstore document.
+        docstore_id = self.knowledge.index_to_docstore_id[idx]
+        doc = self.knowledge.docstore.search(docstore_id)
+        meta_content = doc.metadata
+        return {"title": meta_content['source'], "score": int(score), "content": meta_content["content"]}
+
+    def query_prompt(self, prompt, topk=3, threshold=700):
+        # FAISS returns L2 distances (smaller means closer), so hits whose
+        # score exceeds `threshold` are dropped as irrelevant.
+        embedding = self.knowledge.embedding_function(prompt)
+        scores, indices = self.knowledge.index.search(np.array([embedding], dtype=np.float32), topk)
+        docs, titles = [], set()
+        for j, i in enumerate(indices[0]):
+            if i == -1:  # FAISS pads missing results with -1
+                continue
+            if scores[0][j] > threshold:
+                continue
+            item = self.render_index(i, scores[0][j])
+            if item["title"] in titles:  # deduplicate by source title
+                continue
+            titles.add(item["title"])
+            docs.append(item)
+        return docs
+
+    def get_response(self, docs: list) -> str:
+        # Join the retrieved passages with "---" separators for display.
+        return "---\n".join(doc["content"] for doc in docs)
+
+# knowledge = Knowledge()
+# answer = knowledge.query_prompt("强奸男性犯法吗?")  # "Is it a crime to rape a man?"
+# print(answer)
+# print(knowledge.get_response(answer))
\ No newline at end of file
diff --git a/webui.py b/webui.py
index 453043a..ff8191c 100644
--- a/webui.py
+++ b/webui.py
@@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
 
 from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
+from utils.knowledge import Knowledge
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -37,6 +38,7 @@ def main(
     ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
 
     prompter = Prompter(prompt_template)
+    knowledge = Knowledge()
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -106,6 +108,7 @@ def main(
     ):
         input=None
         prompt = prompter.generate_prompt(instruction, input)
+        legals = knowledge.query_prompt(instruction)
         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -152,7 +155,7 @@ def main(
                         if output[-1] in [tokenizer.eos_token_id]:
                             break
 
-                        yield prompter.get_response(decoded_output)
+                        yield prompter.get_response(decoded_output), knowledge.get_response(legals)
                 print(decoded_output)
             return  # early return for stream_output
 
@@ -168,7 +171,7 @@ def main(
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
         print(output)
-        yield prompter.get_response(output)
+        yield prompter.get_response(output), knowledge.get_response(legals)
 
     gr.Interface(
         fn=evaluate,
@@ -200,6 +203,9 @@ def main(
             gr.inputs.Textbox(
                 lines=8,
                 label="Output",
+            ),
+            gr.inputs.Textbox(
+                label="Legal Ground",
             )
         ],
         title="🦙🌲 LaWGPT",
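For reviewers who want to reproduce the retrieval locally: this PR assumes a FAISS index already exists under `./knowledge`, and `Knowledge.render_index` requires each stored document to carry `source` and `content` keys in its metadata (the passage text is read from metadata, not from `page_content`). The sketch below shows one way to build a compatible index; the corpus directory, the one-file-per-document layout, and the `build_knowledge.py` name are assumptions for illustration, not part of this PR.

```python
# build_knowledge.py -- a minimal sketch, assuming a directory of UTF-8 .txt
# law articles; not part of this PR.
import os

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.faiss import FAISS

CORPUS_DIR = "./corpus"    # hypothetical directory of plain-text law articles
INDEX_DIR = "./knowledge"  # matches Knowledge(knowledge_path="./knowledge")

docs = []
for name in os.listdir(CORPUS_DIR):
    with open(os.path.join(CORPUS_DIR, name), encoding="utf-8") as f:
        text = f.read()
    # Duplicate the text into metadata["content"] because render_index reads
    # the passage from doc.metadata rather than doc.page_content.
    docs.append(Document(page_content=text, metadata={"source": name, "content": text}))

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
FAISS.from_documents(docs, embeddings).save_local(INDEX_DIR)
```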
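Each `yield` in `evaluate` now produces an (answer, legal ground) pair, and Gradio maps the tuple elements in order onto the output components, which is why the `outputs` list gains a second `gr.inputs.Textbox`. With an index in place, the retrieval path can also be smoke-tested outside the web UI; the query below is a hypothetical example, and the scores it prints are L2 distances, so smaller means more relevant.

```python
# Retrieval smoke test, assuming a compatible index exists at ./knowledge.
from utils.knowledge import Knowledge

knowledge = Knowledge()
hits = knowledge.query_prompt("盗窃罪如何量刑?")  # "How is theft sentenced?"
for hit in hits:
    print(hit["score"], hit["title"])  # L2 distance, source document
print(knowledge.get_response(hits))    # passages joined with "---" separators
```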