Files
realtime_voice_bot/src/services/ollama-llm.ts
2026-05-03 01:56:09 +09:00

549 lines
15 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import type { AppConfig } from "../config.js";
import type { Logger } from "../logger.js";
import { loadPrompt } from "../prompt-loader.js";
import { webFetch, webSearch } from "./web-tools.js";
/** A system, user or assistant turn as sent in the `messages` array of a chat request. */
interface OllamaChatMessage {
role: "system" | "user" | "assistant";
content: string;
// Set on assistant turns that requested tools (see runAgentLoop); omitted otherwise.
tool_calls?: OllamaToolCall[];
}
/**
 * The part of the chat response body that this module reads.
 * Every field is optional because the payload is cast from parsed JSON without validation.
 */
interface OllamaChatResponse {
message?: {
content?: string;
tool_calls?: OllamaToolCall[];
};
}
/** A tool invocation requested by the model: the tool's name plus its argument object. */
interface OllamaToolCall {
type: "function";
function: {
name: string;
// Argument values are untyped; callers narrow them via getStringArg / getNumberArg.
arguments: Record<string, unknown>;
};
}
/** Declaration of one tool as passed in the `tools` field of a chat request. */
interface OllamaToolDefinition {
type: "function";
function: {
name: string;
description: string;
// Schema-like description of the tool's arguments.
parameters: {
type: "object";
// Names of the properties the model must supply; omitted when there are none.
required?: string[];
properties: Record<string, unknown>;
};
};
}
/** Message that feeds a tool's output back to the model; appended by runAgentLoop after each tool call. */
interface OllamaToolResultMessage {
role: "tool";
tool_name: string;
// Tool output as a string (executeTool produces JSON text).
content: string;
}
/** Optional hooks for a single generateReply call. */
interface GenerateReplyOptions {
// Called at most once per reply, with the progress message of the first tool call that has one.
onProgress?: (message: string) => void;
}
/** Outcome of deciding whether a user utterance deserves a spoken reply. */
export interface ReplyAssessment {
shouldReply: boolean;
// Hint that answering probably requires a lookup; how callers use it is not visible in this file.
likelyNeedsLookup: boolean;
// Short tag for the decision: "empty", "filler", "too_short", "fallback", "parsed", or a model-supplied string.
reason: string;
}
// Prompt texts obtained once, when this module is evaluated.
// NOTE(review): loadPrompt lives in ../prompt-loader.js and is not visible here —
// presumably it reads the named markdown file; confirm before relying on that.
// System prompt for warmup and for regular replies.
const ASSISTANT_PROMPT = loadPrompt("assistant.md");
// System prompt for the "should we reply at all?" gate in assessReplyNeed.
const REPLY_GATE_PROMPT = loadPrompt("reply-gate.md");
// System prompt for the rewrite pass in repairReplyLanguageIfNeeded.
const REWRITE_KOREAN_PROMPT = loadPrompt("rewrite-korean.md");
/**
 * Builds one tool declaration.
 *
 * @param name Tool name the model uses to call it; must match a case in executeTool.
 * @param description Natural-language description shown to the model.
 * @param properties Argument schema, keyed by argument name (empty for argument-less tools).
 * @param required Names of mandatory arguments; the key is left out entirely when not given.
 * @returns The declaration object in the shape expected by the `tools` request field.
 */
function defineTool(
  name: string,
  description: string,
  properties: Record<string, unknown> = {},
  required?: string[],
): OllamaToolDefinition {
  // `required` is emitted only when provided and always ahead of `properties`,
  // so the serialized request carries the same keys in the same order.
  const parameters: OllamaToolDefinition["function"]["parameters"] = required
    ? { type: "object", required, properties }
    : { type: "object", properties };
  return { type: "function", function: { name, description, parameters } };
}

/** Tools advertised to the model whenever a chat request is made with tools enabled. */
const TOOL_DEFINITIONS: OllamaToolDefinition[] = [
  defineTool(
    "get_current_time",
    "현재 시스템 시간을 Asia/Seoul 기준 ISO 문자열과 사람이 읽기 쉬운 형식으로 반환한다.",
  ),
  defineTool(
    "get_runtime_settings",
    "현재 로컬 LLM 및 STT 실행 설정의 핵심 값만 반환한다.",
  ),
  defineTool(
    "list_project_commands",
    "현재 프로젝트에서 사용 가능한 주요 bun 스크립트 명령 목록을 반환한다.",
  ),
  defineTool(
    "evaluate_math",
    "간단한 산술식을 정확히 계산한다. 숫자, 공백, 소수점, 괄호, + - * / % 만 허용한다.",
    {
      expression: {
        type: "string",
        description: "예: (11434+12341)*412",
      },
    },
    ["expression"],
  ),
  defineTool(
    "web_search",
    "웹 검색 결과 제목, URL, 요약을 가져온다. 최신 정보, 뉴스, 사실 확인이 필요할 때만 사용한다.",
    {
      query: {
        type: "string",
        description: "검색어",
      },
      max_results: {
        type: "number",
        description: "가져올 최대 결과 수. 보통 3~5",
      },
    },
    ["query"],
  ),
  defineTool(
    "fetch_url",
    "주어진 URL의 페이지 제목과 본문 텍스트를 읽어온다. 검색 결과 상세 확인에 사용한다.",
    {
      url: {
        type: "string",
        description: "http 또는 https URL",
      },
      max_chars: {
        type: "number",
        description: "본문에서 가져올 최대 글자 수",
      },
    },
    ["url"],
  ),
];
export class OllamaLlmService {
// Earlier user/assistant turns, replayed on every generateReply call and bounded by trimHistory.
private history: OllamaChatMessage[] = [];
/**
 * @param config Application settings; this class reads the OLLAMA_*, WHISPER_*,
 *   MAX_CONVERSATION_TURNS, AUDIO_SOURCE and DEBUG entries.
 * @param logger Destination for info/warn diagnostics.
 */
constructor(
private readonly config: AppConfig,
private readonly logger: Logger,
) {}
/**
 * Sends a single throwaway prompt through the chat endpoint and logs the answer.
 * Tools are not offered for this request. Presumably used to get the model loaded
 * before the first real request — confirm with the caller.
 *
 * @throws Whatever `chat` throws (non-2xx response, empty model response, network failure).
 */
async warmup(): Promise<void> {
  const probe: OllamaChatMessage[] = [
    { role: "system", content: ASSISTANT_PROMPT },
    { role: "user", content: "준비 상태 확인입니다. 한 단어로만 답하세요." },
  ];
  const { content } = await this.chat(probe);
  this.logger.info("LLM warmup finished", { model: this.config.OLLAMA_MODEL, reply: content });
}
/**
 * Decides whether an utterance should be answered.
 *
 * Local heuristics are tried first; only when they are inconclusive is the model
 * asked (with the reply-gate prompt and tools disabled). If the model's answer
 * cannot be parsed, the method defaults to replying.
 *
 * @param userText Transcribed user utterance.
 * @returns The assessment; `reason` is "fallback" when the model output was unusable.
 * @throws Whatever `chat` throws.
 */
async assessReplyNeed(userText: string): Promise<ReplyAssessment> {
  const quickVerdict = this.assessReplyNeedHeuristically(userText);
  if (quickVerdict) {
    return quickVerdict;
  }

  const gateMessages: OllamaChatMessage[] = [
    { role: "system", content: REPLY_GATE_PROMPT },
    { role: "user", content: userText },
  ];
  const { content } = await this.chat(gateMessages, { enableTools: false });

  return (
    this.parseAssessment(content) ?? {
      shouldReply: true,
      likelyNeedsLookup: this.mightNeedLookup(userText),
      reason: "fallback",
    }
  );
}
/**
 * Produces the assistant's reply to one user utterance and records the exchange.
 *
 * The request consists of the assistant system prompt, the stored history and the
 * new user turn. The raw answer from the tool loop is passed through the language
 * repair step, and only the final text is stored in history.
 *
 * @param userText Transcribed user utterance.
 * @param options Optional progress callback, forwarded to the tool loop.
 * @returns The reply text after any language repair.
 * @throws Whatever the tool loop or the repair step throws; history is left untouched in that case.
 */
async generateReply(userText: string, options?: GenerateReplyOptions): Promise<string> {
  const transcript: Array<OllamaChatMessage | OllamaToolResultMessage> = [
    { role: "system", content: ASSISTANT_PROMPT },
    ...this.history,
    { role: "user", content: userText },
  ];

  const draft = await this.runAgentLoop(transcript, options);
  const finalReply = await this.repairReplyLanguageIfNeeded(draft, userText);

  // Intermediate tool traffic stays in `transcript`; history keeps the visible turns only.
  this.history.push(
    { role: "user", content: userText },
    { role: "assistant", content: finalReply },
  );
  this.trimHistory();

  return finalReply;
}
/** Forgets all stored turns by installing a fresh, empty history array. */
resetConversation(): void {
  const cleared: OllamaChatMessage[] = [];
  this.history = cleared;
}
/**
 * Drops the oldest stored messages so that at most `MAX_CONVERSATION_TURNS * 2`
 * remain (one user and one assistant message per turn).
 *
 * A limit of zero or less clears the history. The previous `slice(-maxMessages)`
 * form could not express that: `slice(-0)` is `slice(0)`, which keeps everything,
 * and a negative limit turned into a positive start index.
 */
private trimHistory(): void {
  const maxMessages = Math.max(0, Math.trunc(this.config.MAX_CONVERSATION_TURNS * 2));
  const overflow = this.history.length - maxMessages;
  if (overflow <= 0) {
    return;
  }
  // Slice from an explicit non-negative index so a limit of 0 yields an empty array.
  this.history = this.history.slice(overflow);
}
/**
 * Runs the chat/tool cycle until the model answers without requesting a tool.
 *
 * Each round sends `messages` with tools enabled, appends the assistant turn, then
 * executes every requested tool and appends its result as a `tool` message.
 * `messages` is mutated in place.
 *
 * A tool that throws (missing argument, failed web request, invalid expression)
 * no longer aborts the reply: the failure is logged and handed back to the model
 * as `{"error": ...}` JSON, the same shape executeTool uses for unknown tools,
 * so the model can recover or explain the failure.
 *
 * @param messages Conversation so far; extended with assistant and tool messages.
 * @param options Optional progress callback, fired at most once, for the first
 *   tool call that has a progress message.
 * @returns The content of the first assistant response that carries no tool calls.
 * @throws Error when the model is still requesting tools after the step limit,
 *   and whatever `chat` throws.
 */
private async runAgentLoop(
  messages: Array<OllamaChatMessage | OllamaToolResultMessage>,
  options?: GenerateReplyOptions,
): Promise<string> {
  // Upper bound on model round-trips, so a model that keeps calling tools cannot loop forever.
  const maxSteps = 6;
  let progressEmitted = false;

  for (let step = 0; step < maxSteps; step += 1) {
    const response = await this.chat(messages, { enableTools: true });
    const toolCalls = response.toolCalls ?? [];
    messages.push({
      role: "assistant",
      content: response.content,
      tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
    });
    if (toolCalls.length === 0) {
      return response.content;
    }

    for (const call of toolCalls) {
      if (!progressEmitted) {
        const progressMessage = this.getProgressMessage(call.function.name);
        if (progressMessage) {
          options?.onProgress?.(progressMessage);
          progressEmitted = true;
        }
      }

      let result: string;
      try {
        result = await this.executeTool(call);
        this.logger.info("LLM tool call", {
          name: call.function.name,
          arguments: call.function.arguments,
          result,
        });
      } catch (error: unknown) {
        const reason = error instanceof Error ? error.message : String(error);
        this.logger.warn("LLM tool call failed", {
          name: call.function.name,
          arguments: call.function.arguments,
          error: reason,
        });
        result = JSON.stringify({ error: reason });
      }

      messages.push({
        role: "tool",
        tool_name: call.function.name,
        content: result,
      });
    }
  }

  throw new Error("도구 호출 루프가 제한 횟수를 넘었습니다.");
}
/**
 * Performs one non-streaming POST to `${OLLAMA_BASE_URL}/api/chat`.
 *
 * @param messages Full message list to send.
 * @param options `enableTools: true` attaches TOOL_DEFINITIONS; otherwise no tools are sent.
 * @returns The trimmed message content ("" when absent) and the requested tool calls ([] when absent).
 * @throws Error on a non-2xx status (message includes status and response body),
 *   or when the response has neither content nor tool calls.
 */
private async chat(
  messages: Array<OllamaChatMessage | OllamaToolResultMessage>,
  options?: { enableTools: boolean },
): Promise<{ content: string; toolCalls: OllamaToolCall[] }> {
  const endpoint = `${this.config.OLLAMA_BASE_URL}/api/chat`;
  // An `undefined` tools value is dropped by JSON.stringify, so the key is absent when tools are off.
  const requestBody = {
    model: this.config.OLLAMA_MODEL,
    messages,
    tools: options?.enableTools ? TOOL_DEFINITIONS : undefined,
    stream: false,
    think: false,
    keep_alive: this.config.OLLAMA_KEEP_ALIVE,
  };

  const response = await fetch(endpoint, {
    method: "POST",
    headers: {
      "content-type": "application/json",
    },
    body: JSON.stringify(requestBody),
  });

  if (!response.ok) {
    const errorBody = await response.text();
    throw new Error(`Ollama API ${response.status}: ${errorBody}`);
  }

  const payload = (await response.json()) as OllamaChatResponse;
  const reply = payload.message;
  const content = reply?.content?.trim() ?? "";
  const toolCalls = reply?.tool_calls ?? [];

  const isEmptyReply = content.length === 0 && toolCalls.length === 0;
  if (isEmptyReply) {
    throw new Error("Ollama 응답에 message.content 와 tool_calls 가 모두 없습니다.");
  }

  return { content, toolCalls };
}
/**
 * Dispatches a model-requested tool call and returns its result as JSON text.
 *
 * Numeric arguments are truncated to integers and clamped: `max_results` to 1..5
 * (default 4), `max_chars` to 1000..10000 (default 6000).
 *
 * @param call Tool call taken from the model response.
 * @returns JSON-encoded tool output; for an unrecognised name, `{"error": "unknown tool: <name>"}`.
 * @throws Error when a required string argument is missing or blank, when the
 *   math expression is rejected, or whatever webSearch / webFetch throw.
 */
private async executeTool(call: OllamaToolCall): Promise<string> {
  const { name, arguments: args } = call.function;

  switch (name) {
    case "get_current_time":
      return JSON.stringify(this.getCurrentTime());

    case "get_runtime_settings":
      return JSON.stringify(this.getRuntimeSettings());

    case "list_project_commands":
      return JSON.stringify(this.listProjectCommands());

    case "evaluate_math": {
      const expression = this.getStringArg(args, "expression");
      return JSON.stringify({
        expression,
        result: this.evaluateMath(expression),
      });
    }

    case "web_search": {
      const query = this.getStringArg(args, "query");
      const requested = Math.trunc(this.getNumberArg(args, "max_results", 4));
      const maxResults = Math.min(5, Math.max(1, requested));
      return JSON.stringify(await webSearch(query, maxResults));
    }

    case "fetch_url": {
      const url = this.getStringArg(args, "url");
      const requested = Math.trunc(this.getNumberArg(args, "max_chars", 6000));
      const maxChars = Math.min(10000, Math.max(1000, requested));
      return JSON.stringify(await webFetch(url, maxChars));
    }

    default:
      return JSON.stringify({
        error: `unknown tool: ${name}`,
      });
  }
}
/**
 * Reports the current time for the `get_current_time` tool.
 *
 * `iso` is now an ISO 8601 timestamp expressed in Seoul time with an explicit
 * `+09:00` offset. It used to be `Date.prototype.toISOString()`, which is always
 * UTC, while the object is labelled "Asia/Seoul" and the tool description
 * promises a Seoul-based ISO string. The instant denoted is unchanged.
 *
 * @returns `timezone` (always "Asia/Seoul"), `iso` (Seoul wall-clock time with
 *   offset) and `local` (the same instant formatted for the ko-KR locale).
 */
private getCurrentTime(): { timezone: string; iso: string; local: string } {
  const now = new Date();

  // Korea currently observes no daylight saving time, so Asia/Seoul is a fixed
  // UTC+09:00: shift the instant, format it as UTC, then relabel the offset.
  const seoulOffsetMs = 9 * 60 * 60 * 1000;
  const iso = new Date(now.getTime() + seoulOffsetMs).toISOString().replace(/Z$/, "+09:00");

  return {
    timezone: "Asia/Seoul",
    iso,
    local: new Intl.DateTimeFormat("ko-KR", {
      timeZone: "Asia/Seoul",
      dateStyle: "full",
      timeStyle: "long",
    }).format(now),
  };
}
/**
 * Collects the configuration values exposed through the `get_runtime_settings` tool.
 *
 * @returns A snake_case snapshot of selected config entries; `audio_source` is
 *   `null` when the config value is null or undefined.
 */
private getRuntimeSettings(): Record<string, unknown> {
  const {
    OLLAMA_BASE_URL,
    OLLAMA_MODEL,
    OLLAMA_KEEP_ALIVE,
    MAX_CONVERSATION_TURNS,
    WHISPER_MODEL,
    WHISPER_LANGUAGE,
    WHISPER_DEVICE,
    WHISPER_COMPUTE_TYPE,
    WHISPER_BEAM_SIZE,
    AUDIO_SOURCE,
    DEBUG,
  } = this.config;

  return {
    ollama_base_url: OLLAMA_BASE_URL,
    ollama_model: OLLAMA_MODEL,
    ollama_keep_alive: OLLAMA_KEEP_ALIVE,
    max_conversation_turns: MAX_CONVERSATION_TURNS,
    whisper_model: WHISPER_MODEL,
    whisper_language: WHISPER_LANGUAGE,
    whisper_device: WHISPER_DEVICE,
    whisper_compute_type: WHISPER_COMPUTE_TYPE,
    whisper_beam_size: WHISPER_BEAM_SIZE,
    audio_source: AUDIO_SOURCE ?? null,
    debug: DEBUG,
  };
}
/**
 * Returns the fixed list of bun script commands reported by the `list_project_commands` tool.
 * The list is hard-coded here, not read from the project manifest.
 */
private listProjectCommands(): { commands: string[] } {
  const commands = [
    "bun run setup",
    "bun run setup:stt",
    "bun run setup:llm",
    "bun run setup:tts",
    "bun run devices",
    "bun run test:stt",
    "bun run test:sttllm",
    "bun run test:llm",
    "bun run test:tts -- \"안녕하세요\"",
  ];
  return { commands };
}
/**
 * Reads a required string argument from a tool call.
 *
 * @param args Argument object supplied by the model.
 * @param name Key to read.
 * @returns The value with surrounding whitespace removed.
 * @throws Error when the value is not a string or is blank after trimming.
 */
private getStringArg(args: Record<string, unknown>, name: string): string {
  const raw = args[name];
  // Non-string values collapse to "" so that a single emptiness check covers both failure modes.
  const text = typeof raw === "string" ? raw.trim() : "";
  if (text.length === 0) {
    throw new Error(`도구 인자 ${name} 가 비어 있습니다.`);
  }
  return text;
}
/**
 * Evaluates an arithmetic expression for the `evaluate_math` tool.
 *
 * Only digits, whitespace, `.`, parentheses and `+ - * / %` pass the character
 * whitelist; the expression is then evaluated as strict-mode JavaScript.
 *
 * NOTE(review): evaluation goes through the `Function` constructor on
 * model-supplied text. The whitelist admits no identifier characters, which is
 * what keeps this contained — do not widen the pattern without revisiting that.
 * Whitelisted but malformed input (e.g. "1 +") surfaces as the engine's own
 * SyntaxError, not as one of the errors below.
 *
 * @param expression Expression text.
 * @returns The finite numeric result.
 * @throws Error when a disallowed character is present, or when the result is
 *   not a finite number (e.g. division by zero).
 */
private evaluateMath(expression: string): number {
  const allowedCharacters = /^[0-9+\-*/%().\s]+$/;
  if (!allowedCharacters.test(expression)) {
    throw new Error("허용되지 않은 문자가 포함된 산술식입니다.");
  }

  const compute = Function(`"use strict"; return (${expression});`);
  const value: unknown = compute();

  if (typeof value === "number" && Number.isFinite(value)) {
    return value;
  }
  throw new Error("산술식 계산 결과가 유효하지 않습니다.");
}
/**
 * Reads an optional numeric argument from a tool call.
 *
 * Accepts finite numbers and numeric strings. Blank strings now yield the
 * fallback: `Number("")` and `Number("   ")` evaluate to 0, so they previously
 * slipped through as a real value of 0 and bypassed the default.
 *
 * @param args Argument object supplied by the model.
 * @param name Key to read.
 * @param fallback Value returned when the argument is absent, blank, non-numeric or not finite.
 * @returns The parsed finite number, or `fallback`.
 */
private getNumberArg(args: Record<string, unknown>, name: string, fallback: number): number {
  const value = args[name];

  if (typeof value === "number") {
    return Number.isFinite(value) ? value : fallback;
  }

  if (typeof value === "string") {
    const text = value.trim();
    // Guard against the empty string before converting: Number("") is 0, not NaN.
    if (text.length > 0) {
      const parsed = Number(text);
      if (Number.isFinite(parsed)) {
        return parsed;
      }
    }
  }

  return fallback;
}
/**
 * Asks the model to rewrite a reply when its script mix suggests it is not Korean.
 *
 * The repair is best effort. `chat` throws on an empty model response as well as
 * on transport errors, which made the "keep the original when the rewrite is
 * empty" fallback unreachable and turned a usable reply into a failed request.
 * Any failure of the rewrite call is now logged and the original reply returned.
 *
 * @param reply Reply produced by the tool loop.
 * @param userText The user utterance the reply answers; included for context in the rewrite request.
 * @returns The rewritten reply, or `reply` unchanged when no repair is needed,
 *   the rewrite is empty, or the rewrite call fails.
 */
private async repairReplyLanguageIfNeeded(reply: string, userText: string): Promise<string> {
  if (!this.needsLanguageRepair(reply)) {
    return reply;
  }

  this.logger.warn("Reply language repair triggered", {
    reply,
    analysis: this.analyzeScriptUsage(reply),
  });

  try {
    const repaired = await this.chat(
      [
        {
          role: "system",
          content: REWRITE_KOREAN_PROMPT,
        },
        {
          role: "user",
          content: `원문 질문: ${userText}\n기존 답변: ${reply}`,
        },
      ],
      { enableTools: false },
    );
    const normalized = repaired.content.trim();
    if (!normalized) {
      return reply;
    }
    return normalized;
  } catch (error: unknown) {
    // An imperfect reply is better than none: keep what we already have.
    this.logger.warn("Reply language repair failed; keeping original reply", {
      error: error instanceof Error ? error.message : String(error),
    });
    return reply;
  }
}
/**
 * Tells whether a reply should go through the rewrite step.
 *
 * @param text Reply text to inspect.
 * @returns `true` when the text contains any letter that is neither Hangul nor
 *   Latin, or when it contains Latin letters but no Hangul at all.
 */
private needsLanguageRepair(text: string): boolean {
  const { hangul, latin, otherLetters } = this.analyzeScriptUsage(text);
  const hasForeignScript = otherLetters > 0;
  const isLatinOnly = hangul === 0 && latin > 0;
  return hasForeignScript || isLatinOnly;
}
/**
 * Counts the letters of a text by script.
 *
 * Iterates by code point; anything that is not a Unicode letter (digits,
 * punctuation, whitespace, emoji) is ignored.
 *
 * @param text Text to analyse.
 * @returns Counts of Hangul letters, Latin letters and letters of any other script.
 */
private analyzeScriptUsage(text: string): { hangul: number; latin: number; otherLetters: number } {
  const letterPattern = /\p{Letter}/u;
  const hangulPattern = /\p{Script=Hangul}/u;
  const latinPattern = /\p{Script=Latin}/u;

  const tally = { hangul: 0, latin: 0, otherLetters: 0 };
  for (const codePoint of text) {
    if (!letterPattern.test(codePoint)) {
      continue;
    }
    if (hangulPattern.test(codePoint)) {
      tally.hangul += 1;
    } else if (latinPattern.test(codePoint)) {
      tally.latin += 1;
    } else {
      tally.otherLetters += 1;
    }
  }
  return tally;
}
/**
 * Maps a tool name to the interim message announced while that tool runs.
 *
 * @param toolName Name of the tool about to be executed.
 * @returns The announcement for the web search / fetch tools, `null` for every other tool.
 */
private getProgressMessage(toolName: string): string | null {
  const announcesLookup = toolName === "web_search" || toolName === "fetch_url";
  return announcesLookup ? "검색해볼게요." : null;
}
/**
 * Extracts a reply assessment from the gate model's free-form answer.
 *
 * The span from the first `{` to the last `}` is parsed as JSON. Both snake_case
 * and camelCase keys are honoured; a flag counts as set only when it is the
 * boolean `true`.
 *
 * @param content Raw model output.
 * @returns The assessment, or `null` when no brace-delimited span exists or it is not valid JSON.
 */
private parseAssessment(content: string): ReplyAssessment | null {
  const jsonSpan = /\{[\s\S]*\}/.exec(content);
  if (jsonSpan === null) {
    return null;
  }

  let fields: Record<string, unknown>;
  try {
    fields = JSON.parse(jsonSpan[0]) as Record<string, unknown>;
  } catch {
    return null;
  }

  const isSet = (snakeKey: string, camelKey: string): boolean =>
    fields[snakeKey] === true || fields[camelKey] === true;

  return {
    shouldReply: isSet("should_reply", "shouldReply"),
    likelyNeedsLookup: isSet("likely_needs_lookup", "likelyNeedsLookup"),
    reason: typeof fields.reason === "string" ? fields.reason : "parsed",
  };
}
/**
 * Cheap pre-filter run before the model-based gate in assessReplyNeed.
 *
 * @param userText Transcribed user utterance.
 * @returns A "do not reply" assessment with reason "empty", "filler" or
 *   "too_short", or `null` when no heuristic applies.
 */
private assessReplyNeedHeuristically(userText: string): ReplyAssessment | null {
const normalized = userText.trim();
// Nothing left after trimming: no reply.
if (!normalized) {
return {
shouldReply: false,
likelyNeedsLookup: false,
reason: "empty",
};
}
// Whole-utterance match against a list of filler sounds.
// NOTE(review): every alternative in the pattern below is a bare `+`, i.e. the
// literal characters in front of each quantifier are missing from this copy of
// the file (possibly characters flagged as ambiguous Unicode by the viewer).
// As written the pattern is not a valid regular expression — restore it from
// the repository copy before shipping.
if (/^(+|+|+|+|+|+|+|+|+|+|+|+|+|+|+|\.?)$/u.test(normalized)) {
return {
shouldReply: false,
likelyNeedsLookup: false,
reason: "filler",
};
}
// One- or two-character utterances are skipped unless they contain a question mark.
// NOTE(review): the character class may have lost a full-width question mark — verify.
if (normalized.length <= 2 && !/[?]/.test(normalized)) {
return {
shouldReply: false,
likelyNeedsLookup: false,
reason: "too_short",
};
}
return null;
}
/**
 * Keyword check used for the `likelyNeedsLookup` hint when the gate model's
 * answer could not be parsed.
 *
 * NOTE(review): every alternative in the pattern is empty in this copy of the
 * file, so it matches any input and this method always returns `true`. The
 * keyword list appears to have been lost — restore it from the repository copy.
 *
 * @param text Transcribed user utterance.
 * @returns Whether the text matches the keyword pattern.
 */
private mightNeedLookup(text: string): boolean {
return /(||||||||||||)/u.test(text);
}
}