Compare commits
4 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 7f5de57f11 | |
| | cb6a4a9f6b | |
| | dc849bd282 | |
| | 39eed4febe | |
utils/knowledge.py (new file, 46 lines)

@@ -0,0 +1,46 @@
```python
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import numpy as np

__all__ = ["Knowledge"]


class Knowledge(object):
    def __init__(self, knowledge_path="./knowledge", embedding_name="GanymedeNil/text2vec-large-chinese") -> None:
        # Load the embedding model and the prebuilt FAISS index from disk.
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_name)
        self.knowledge = FAISS.load_local(knowledge_path, embeddings=self.embeddings)
        # EMBEDDINGS.client = sentence_transformers.SentenceTransformer("/home/wnjxyk/Projects/wenda/model/text2vec-large-chinese", device="cuda")

    def render_index(self, idx, score):
        # Map a raw FAISS position back to its docstore document and metadata.
        docstore_id = self.knowledge.index_to_docstore_id[idx]
        doc = self.knowledge.docstore.search(docstore_id)
        meta_content = doc.metadata
        return {"title": meta_content["source"], "score": int(score), "content": meta_content["content"]}

    def query_prompt(self, prompt, topk=3, threshold=700):
        # Embed the prompt, retrieve the topk nearest documents, and keep only
        # hits under the L2-distance threshold (lower scores mean closer matches),
        # deduplicated by title.
        embedding = self.knowledge.embedding_function(prompt)
        scores, indices = self.knowledge.index.search(np.array([embedding], dtype=np.float32), topk)
        docs, titles = [], set()
        for j, i in enumerate(indices[0]):
            if i == -1:  # FAISS pads missing results with -1
                continue
            if scores[0][j] > threshold:
                continue
            item = self.render_index(i, scores[0][j])
            if item["title"] in titles:
                continue
            titles.add(item["title"])
            docs.append(item)
        return docs

    def get_response(self, output: list) -> str:
        # Join the retrieved passages into one string, separated by "---".
        first, res = True, ""
        for doc in output:
            if not first:
                res += "---\n"
            res += doc["content"]
            first = False
        return res


# knowledge = Knowledge()
# answer = knowledge.query_prompt("强奸男性犯法吗?")  # "Is it a crime to rape a man?"
# print(answer)
# print(knowledge.get_response(answer))
```
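For context, `Knowledge.render_index` expects each document's metadata to carry `source` and `content` keys, so the FAISS index on disk must have been built that way. Below is a minimal sketch of building a compatible index with langchain's FAISS wrapper; the corpus strings and output path are illustrative, not from the repo:

```python
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

# Hypothetical corpus entries; real entries would be statute passages.
texts = ["Article 236 of the Criminal Law: ..."]
metadatas = [{"source": "Criminal Law, Article 236", "content": texts[0]}]

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
index = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
index.save_local("./knowledge")  # the default knowledge_path that Knowledge() loads
```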
webui.py (10 changes)
```diff
@@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
 from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
+from utils.knowledge import Knowledge

 if torch.cuda.is_available():
     device = "cuda"
@@ -37,6 +38,7 @@ def main(
     ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

     prompter = Prompter(prompt_template)
+    knowledge = Knowledge()
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -106,6 +108,7 @@ def main(
     ):
         input=None
         prompt = prompter.generate_prompt(instruction, input)
+        legals = knowledge.query_prompt(instruction)
         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -152,7 +155,7 @@ def main(
                     if output[-1] in [tokenizer.eos_token_id]:
                         break

-                    yield prompter.get_response(decoded_output)
+                    yield prompter.get_response(decoded_output), knowledge.get_response(legals)
             print(decoded_output)
             return  # early return for stream_output
@@ -168,7 +171,7 @@ def main(
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
         print(output)
-        yield prompter.get_response(output)
+        yield prompter.get_response(output), knowledge.get_response(legals)

     gr.Interface(
         fn=evaluate,
@@ -200,6 +203,9 @@ def main(
             gr.inputs.Textbox(
                 lines=8,
                 label="Output",
             ),
+            gr.inputs.Textbox(
+                label="Legal Ground",
+            )
         ],
         title="🦙🌲 LaWGPT",
```
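The two-value `yield` pairs positionally with the two output textboxes, so the model answer fills "Output" and the retrieved passages fill "Legal Ground". A stripped-down sketch of that wiring, mirroring the legacy `gr.inputs.Textbox` calls the diff itself uses; the `evaluate` body here is a placeholder, not the repo's streaming code:

```python
import gradio as gr

def evaluate(instruction):
    # Placeholder: the real webui streams model output and retrieves passages.
    answer = f"model answer for: {instruction}"
    legal_ground = "retrieved statute text"
    return answer, legal_ground  # two values -> two output textboxes, in order

gr.Interface(
    fn=evaluate,
    inputs=[gr.inputs.Textbox(lines=2, label="Instruction")],
    outputs=[
        gr.inputs.Textbox(lines=8, label="Output"),
        gr.inputs.Textbox(label="Legal Ground"),
    ],
    title="🦙🌲 LaWGPT",
).launch()
```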