Files
cheetah/Cheetah/ConversationAnalyzer.swift
2023-03-26 18:42:42 -04:00

163 lines
5.5 KiB
Swift

import LibWhisper
import Combine
enum ContextKey: String {
case transcript
case question
case answerInCode
case answer
case previousAnswer
case highlightedAnswer
case codeAnswer
case browserCode
case browserLogs
}
typealias AnalysisContext = [ContextKey: String]
extension AnalysisContext {
var answerInCode: Bool {
return self[.answerInCode]?.first?.lowercased() == "y"
}
}
enum AnalysisError: Error {
case missingRequiredContextKey(ContextKey)
}
extension PromptGenerator {
func extractQuestion(context: AnalysisContext) throws -> ModelInput? {
if let transcript = context[.transcript] {
return extractQuestion(transcript: transcript)
} else {
throw AnalysisError.missingRequiredContextKey(.transcript)
}
}
func answerQuestion(context: AnalysisContext) throws -> ModelInput? {
guard let question = context[.question] else {
throw AnalysisError.missingRequiredContextKey(.question)
}
if context.answerInCode {
return nil
} else if let answer = context[.previousAnswer] {
return answerQuestion(question, previousAnswer: answer)
} else if let answer = context[.highlightedAnswer] {
return answerQuestion(question, highlightedAnswer: answer)
} else {
return answerQuestion(question)
}
}
func writeCode(context: AnalysisContext) -> ModelInput? {
if context.answerInCode, let question = context[.question] {
return writeCode(task: question)
} else {
return nil
}
}
func analyzeBrowserCode(context: AnalysisContext) -> ModelInput? {
if let code = context[.browserCode], let logs = context[.browserLogs] {
return analyzeBrowserCode(code, logs: logs, task: context[.question])
} else {
return nil
}
}
}
extension ContextKey {
var `set`: (String, inout AnalysisContext) -> () {
return { output, context in
context[self] = output
}
}
func extract(using regexArray: Regex<(Substring, answer: Substring)>...) -> (String, inout AnalysisContext) -> () {
return { output, context in
for regex in regexArray {
if let match = output.firstMatch(of: regex) {
context[self] = String(match.answer)
}
}
}
}
}
let extractQuestion: (String, inout AnalysisContext) -> () = { output, context in
let regex = /Extracted question: (?<question>[^\n]+)(?:\nAnswer in code: (?<answerInCode>Yes|No))?/.ignoresCase()
if let match = output.firstMatch(of: regex) {
context[.question] = String(match.question)
if let answerInCode = match.answerInCode {
context[.answerInCode] = String(answerInCode)
}
}
}
let finalAnswerRegex = /Final answer:\n(?<answer>[-].+$)/.dotMatchesNewlines()
let answerOnlyRegex = /(?<answer>[-].+$)/.dotMatchesNewlines()
class ConversationAnalyzer {
let stream: WhisperStream
let generator: PromptGenerator
let executor: OpenAIExecutor
init(stream: WhisperStream, generator: PromptGenerator, executor: OpenAIExecutor) {
self.stream = stream
self.generator = generator
self.executor = executor
}
var context = [ContextKey: String]()
func answer(refine: Bool = false, selection: Range<String.Index>? = nil) async throws {
let chain = PromptChain(
generator: generator.extractQuestion,
updateContext: extractQuestion,
maxTokens: 250,
children: [
Prompt(generator: generator.answerQuestion,
updateContext: ContextKey.answer.extract(using: finalAnswerRegex, answerOnlyRegex),
maxTokens: 500),
Prompt(generator: generator.writeCode,
updateContext: ContextKey.codeAnswer.set,
maxTokens: 1000),
])
var newContext: AnalysisContext = [
.transcript: String(stream.segments.text)
]
if refine, let previousAnswer = context[.answer] {
if let selection = selection {
let highlightedAnswer = previousAnswer[..<selection.lowerBound] + " [start highlighted text] " + previousAnswer[selection] + " [end highlighted text] " + previousAnswer[selection.upperBound...]
newContext[.highlightedAnswer] = String(highlightedAnswer)
} else {
newContext[.previousAnswer] = previousAnswer
}
}
context = try await executor.execute(chain: chain, context: newContext)
}
func analyzeCode(extensionState: BrowserExtensionState) async throws {
let newContext: AnalysisContext = [
.transcript: String(stream.segments.text),
.browserCode: extensionState.codeDescription,
.browserLogs: extensionState.logsDescription
]
let chain = PromptChain(
generator: generator.extractQuestion,
updateContext: extractQuestion,
maxTokens: 250,
children: [
Prompt(generator: generator.analyzeBrowserCode,
updateContext: ContextKey.answer.set,
maxTokens: 500)
])
context = try await executor.execute(chain: chain, context: newContext)
}
}