diff --git a/src/lib/ChatRequest.svelte b/src/lib/ChatRequest.svelte
index dede618..ba1ec35 100644
--- a/src/lib/ChatRequest.svelte
+++ b/src/lib/ChatRequest.svelte
@@ -21,6 +21,7 @@ export class ChatRequest {
   updating: boolean|number = false
   updatingMessage: string = ''
   controller:AbortController
+  providerData: Record<string, any> = {}
 
   setChat (chat: Chat) {
     this.chat = chat
diff --git a/src/lib/ChatSettingsModal.svelte b/src/lib/ChatSettingsModal.svelte
index b35b30f..be98ff9 100644
--- a/src/lib/ChatSettingsModal.svelte
+++ b/src/lib/ChatSettingsModal.svelte
@@ -1,6 +1,6 @@
\ No newline at end of file
diff --git a/src/lib/providers/openai/models.svelte b/src/lib/providers/openai/models.svelte
index 24f9fad..c46fbe5 100644
--- a/src/lib/providers/openai/models.svelte
+++ b/src/lib/providers/openai/models.svelte
@@ -20,7 +20,8 @@ const hiddenSettings = {
   assistantMessageEnd: true,
   systemMessageStart: true,
   systemMessageEnd: true,
-  repetitionPenalty: true
+  repetitionPenalty: true,
+  holdSocket: true
   // leadPrompt: true
 }
diff --git a/src/lib/providers/petals/request.svelte b/src/lib/providers/petals/request.svelte
index bd15319..53a8984 100644
--- a/src/lib/providers/petals/request.svelte
+++ b/src/lib/providers/petals/request.svelte
@@ -5,6 +5,29 @@
   import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
   import { getModelMaxTokens } from '../../Stats.svelte'
   import { updateMessages } from '../../Storage.svelte'
+  import { escapeRegex } from '../../Util.svelte'
+
+const levenshteinDistance = (str1 = '', str2 = '') => {
+  const track = Array(str2.length + 1).fill(null).map(() =>
+    Array(str1.length + 1).fill(null))
+  for (let i = 0; i <= str1.length; i += 1) {
+    track[0][i] = i
+  }
+  for (let j = 0; j <= str2.length; j += 1) {
+    track[j][0] = j
+  }
+  for (let j = 1; j <= str2.length; j += 1) {
+    for (let i = 1; i <= str1.length; i += 1) {
+      const indicator = str1[i - 1] === str2[j - 1] ? 0 : 1
+      track[j][i] = Math.min(
+        track[j][i - 1] + 1, // deletion
+        track[j - 1][i] + 1, // insertion
+        track[j - 1][i - 1] + indicator // substitution
+      )
+    }
+  }
+  return track[str2.length][str1.length]
+}
 
 export const chatRequest = async (
   request: Request,
@@ -16,8 +39,10 @@ export const chatRequest = async (
   const chatSettings = chat.settings
   const model = chatRequest.getModel()
   const modelDetail = getModelDetail(model)
-  const ws = new WebSocket(getEndpoint(model))
   const signal = chatRequest.controller.signal
+  const providerData = chatRequest.providerData.petals || {}
+  chatRequest.providerData.petals = providerData
+  let ws: WebSocket = providerData.ws
   const abortListener = (e:Event) => {
     chatRequest.updating = false
     chatRequest.updatingMessage = ''
@@ -26,9 +51,17 @@
     ws.close()
   }
   signal.addEventListener('abort', abortListener)
+  const startSequence = getStartSequence(chat)
   let stopSequences = [...new Set(getStopSequence(chat).split(',').filter(s => s.trim()).concat((modelDetail.stop || ['###', '</s>']).slice()))]
-  const stopSequence = '</s>'
+  let stopSequence = stopSequences[0] || '###'
+  if (startSequence.length) {
+    const sld = stopSequences.slice()
+      .filter(s => s === '###' || s === '</s>' || countTokens(model, s) === 1)
+      .sort((a, b) => levenshteinDistance(a, startSequence) - levenshteinDistance(b, startSequence))
+    stopSequence = sld[0] || stopSequence
+  }
   stopSequences.push(stopSequence)
+
   const delimiter = getDelimiter(chat)
   const leadPromptSequence = getLeadPrompt(chat)
   if (delimiter) stopSequences.unshift(delimiter.trim())
@@ -62,56 +95,55 @@
   const buildMessage = (m: Message): string => {
     return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
   }
+  const buildInputArray = (a) => {
+    return a.reduce((a, m, i) => {
+      let c = buildMessage(m)
+      let replace = false
+      const lm = a[a.length - 1]
+      // Merge content if needed
+      if (lm) {
+        if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
+          c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
+          replace = true
+        } else {
+          c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
+        }
+        if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
+          c = c.replaceAll('[[USER_PROMPT]]', lm.content)
+          replace = true
+        } else {
+          c = c.replaceAll('[[USER_PROMPT]]', '')
+        }
+      }
+      // Clean up merge fields on last
+      if (!rMessages[i + 1]) {
+        c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
+      }
+      const result = {
+        role: m.role,
+        content: c.trim()
+      } as Message
+      if (replace) {
+        a[a.length - 1] = result
+      } else {
+        a.push(result)
+      }
+      return a
+    }, [] as Message[])
+  }
   const lastMessage = rMessages[rMessages.length - 1]
   let doLead = true
   if (lastMessage && lastMessage.role === 'assistant') {
     lastMessage.content = leadPromptSequence + lastMessage.content
     doLead = false
   }
-  const inputArray = rMessages.reduce((a, m, i) => {
-    let c = buildMessage(m)
-    let replace = false
-    const lm = a[a.length - 1]
-    // Merge content if needed
-    if (lm) {
-      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
-      }
-      if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
-        c = c.replaceAll('[[USER_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[USER_PROMPT]]', '')
-      }
-    }
-    // Clean up merge fields on last
-    if (!rMessages[i + 1]) {
-      c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
-    }
-    const result = {
-      role: m.role,
-      content: c.trim()
-    } as Message
-    if (replace) {
-      a[a.length - 1] = result
-    } else {
-      a.push(result)
-    }
-    return a
-  }, [] as Message[])
+  // const inputArray = buildInputArray(rMessages).map(m => m.content)
+  const lInputArray = buildInputArray(rMessages.slice(0, -1)).map(m => m.content)
+  const nInputArray = buildInputArray(rMessages.slice(-1)).map(m => m.content)
   const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
-  const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(delimiter) + leadPrompt
+  const lastPrompt = startSequence + lInputArray.join(delimiter)
+  const nextPrompt = nInputArray.slice(-1).join('') + leadPrompt
 
-  let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-  const promptTokenCount = countTokens(model, fullPromptInput)
-  if (promptTokenCount > maxLen) {
-    maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
-  }
-  // update with real count
-  chatResponse.setPromptTokenCount(promptTokenCount)
   // set up the request
   chatResponse.onFinish(() => {
     const message = chatResponse.getMessages()[0]
@@ -124,51 +156,120 @@
         }
       }
     }
-    ws.close()
+    !chatSettings.holdSocket && ws.close()
   })
-  ws.onopen = () => {
-    ws.send(JSON.stringify({
-      type: 'open_inference_session',
-      model,
-      max_length: maxLen
-    }))
-    ws.onmessage = event => {
+
+  let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
+
+  let inputPrompt = startSequence
+
+  const getNewWs = ():Promise<WebSocket> => new Promise((resolve, reject) => {
+    // console.warn('requesting new ws')
+    const nws = new WebSocket(getEndpoint(model))
+    let opened = false
+    let done = false
+    nws.onmessage = event => {
+      if (done) return
+      done = true
       const response = JSON.parse(event.data)
       if (!response.ok) {
         const err = new Error('Error opening socket: ' + response.traceback)
         chatResponse.updateFromError(err.message)
+        console.error(err)
+        reject(err)
+      }
+      nws.onerror = err => {
         console.error(err)
         throw err
       }
-      const petalsRequest = {
-        type: 'generate',
-        inputs: fullPromptInput,
-        max_new_tokens: 1, // wait for up to 1 tokens before displaying
-        stop_sequence: stopSequence,
-        do_sample: 1, // enable top p and the like
-        temperature,
-        top_p: topP,
-        repetition_penalty: chatSettings.repetitionPenalty
-      } as any
-      if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
-      // Update token count
+      // console.warn('got new ws')
+      inputPrompt = lastPrompt
+      providerData.knownBuffer = ''
+      providerData.ws = nws
+      resolve(nws)
+    }
+    nws.onclose = () => {
+      chatResponse.updateFromClose()
+    }
+    nws.onerror = err => {
+      if (done) return
+      done = true
+      console.error(err)
+      reject(err)
+    }
+    nws.onopen = () => {
+      if (opened) return
+      opened = true
+      const promptTokenCount = countTokens(model, lastPrompt + delimiter + nextPrompt)
+      if (promptTokenCount > maxLen) {
+        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+      }
+      // update with real count
       chatResponse.setPromptTokenCount(promptTokenCount)
-      ws.send(JSON.stringify(petalsRequest))
-      ws.onmessage = event => {
-        // Remove updating indicator
-        chatRequest.updating = 1 // hide indicator, but still signal we're updating
-        chatRequest.updatingMessage = ''
-        const response = JSON.parse(event.data)
-        if (!response.ok) {
-          if (response.traceback.includes('Maximum length exceeded')) {
-            return chatResponse.finish('length')
-          }
-          const err = new Error('Error in response: ' + response.traceback)
-          console.error(err)
-          chatResponse.updateFromError(err.message)
-          throw err
-        }
-        chatResponse.updateFromAsyncResponse(
+      nws.send(JSON.stringify({
+        type: 'open_inference_session',
+        model,
+        max_length: chatSettings.holdSocket ? maxTokens : maxLen
+      }))
+    }
+  })
+
+  const wsOpen = (ws && ws.readyState === WebSocket.OPEN)
+
+  if (!chatSettings.holdSocket || wsOpen) {
+    const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
+    const kb = providerData.knownBuffer.replace(rgxp, '')
+    const lp = lastPrompt.replace(rgxp, '')
+    const lm = kb === lp
+    if (!lm || countTokens(model, providerData.knownBuffer + inputPrompt) >= maxTokens) {
+      wsOpen && ws.close()
+      ws = await getNewWs()
+    }
+  }
+
+  if (!ws || ws.readyState !== WebSocket.OPEN) {
+    ws = await getNewWs()
+  }
+
+  inputPrompt += delimiter + nextPrompt
+  providerData.knownBuffer += inputPrompt
+
+  // console.log(
+  //   '\n\n*** inputPrompt: ***\n\n',
+  //   inputPrompt
+
+  // )
+
+  const petalsRequest = {
+    type: 'generate',
+    inputs: inputPrompt,
+    max_new_tokens: 1, // wait for up to 1 tokens before displaying
+    stop_sequence: stopSequence,
+    do_sample: 1, // enable top p and the like
+    temperature,
+    top_p: topP,
+    repetition_penalty: chatSettings.repetitionPenalty
+  } as any
+  if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+  // Update token count
+  chatResponse.setPromptTokenCount(countTokens(model, providerData.knownBuffer))
+  ws.onmessage = event => {
+    // Remove updating indicator
+    chatRequest.updating = chatRequest.updating && 1 // hide indicator, but still signal we're updating
+    chatRequest.updatingMessage = ''
+    const response = JSON.parse(event.data)
+    if (!response.ok) {
+      if (response.traceback.includes('Maximum length exceeded')) {
+        return chatResponse.finish('length')
+      }
+      if (!chatRequest.updating) return
+      const err = new Error('Error in response: ' + response.traceback)
+      console.error(err)
+      chatResponse.updateFromError(err.message)
+      throw err
+    }
+    providerData.knownBuffer += response.outputs
+    chatResponse.updateFromAsyncResponse(
       {
         model,
         choices: [{
           delta: {
             content: response.outputs,
             role: 'assistant'
           },
           finish_reason: (response.stop ? 'stop' : null)
         }]
       } as any
-        )
-        if (chatSettings.aggressiveStop && !response.stop) {
-          // check if we should've stopped
-          const message = chatResponse.getMessages()[0]
-          const pad = 10 // look back 10 characters + stop sequence
-          if (message) {
-            const mc = (message.content).trim()
-            for (let i = 0, l = stopSequences.length; i < l; i++) {
-              const ss = stopSequences[i].trim()
-              const ind = mc.slice(0 - (ss.length + pad)).indexOf(ss)
-              if (ind > -1) {
-                const offset = (ss.length + pad) - ind
-                message.content = mc.slice(0, mc.length - offset)
-                response.stop = true
-                updateMessages(chat.id)
-                chatResponse.finish()
-                ws.close()
-              }
+    )
+    if (chatSettings.aggressiveStop && !response.stop) {
+      // check if we should've stopped
+      const message = chatResponse.getMessages()[0]
+      const pad = 10 // look back 10 characters + stop sequence
+      if (message) {
+        const mc = (message.content).trim()
+        for (let i = 0, l = stopSequences.length; i < l; i++) {
+          const ss = stopSequences[i].trim()
+          const ind = mc.slice(0 - (ss.length + pad)).indexOf(ss)
+          if (ind > -1) {
+            const offset = (ss.length + pad) - ind
+            message.content = mc.slice(0, mc.length - offset)
+            response.stop = true
+            updateMessages(chat.id)
+            chatResponse.finish()
+            if (ss !== stopSequence) {
+              providerData.knownBuffer += stopSequence
+            }
+            ws.close()
           }
         }
       }
     }
-  ws.onclose = () => {
-    chatResponse.updateFromClose()
-  }
-  ws.onerror = err => {
-    console.error(err)
-    throw err
-  }
   }
+  ws.send(JSON.stringify(petalsRequest))
 
   return chatResponse
 }
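Review note: the core of this change is the holdSocket reuse check — the Petals inference session is only kept open when the text the server has already seen (providerData.knownBuffer) still matches the locally rebuilt history (lastPrompt) after ignoring whitespace, <s>/</s> markers, and the stop sequence. Below is a minimal standalone TypeScript sketch of that comparison, not part of the patch: escapeRegex is assumed to behave like the helper in src/lib/Util.svelte, and canReuseSession plus the sample prompts are hypothetical names invented for illustration.

    // Sketch of the holdSocket reuse test, mirroring the patch's
    // regex-normalized buffer comparison (assumptions noted above).
    const escapeRegex = (s: string): string => s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')

    const canReuseSession = (knownBuffer: string, lastPrompt: string, stopSequence: string): boolean => {
      // Strip BOS/EOS markers, whitespace, and the stop sequence before
      // comparing, since those are the only tokens that legitimately differ
      // between the server-side session buffer and the rebuilt chat history.
      const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
      return knownBuffer.replace(rgxp, '') === lastPrompt.replace(rgxp, '')
    }

    // Appending a new exchange keeps the held session:
    console.log(canReuseSession('### User: Hi</s>', '### User: Hi', '###')) // true
    // Editing or deleting an earlier message invalidates it, forcing a new socket:
    console.log(canReuseSession('### User: Hi</s>', '### User: Bye', '###')) // false

When the check fails, or the combined buffer would exceed maxTokens, getNewWs() opens a fresh socket and replays the full history. The Levenshtein sort near the top of the diff appears to serve the same goal: it picks the stop sequence (among '###', '</s>', and other single-token candidates) closest by edit distance to the start sequence, minimizing drift between the held session's buffer and the next request's prompt.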