From 3b72f73080ae2999b850a4e65a1dcad11005425f Mon Sep 17 00:00:00 2001
From: Webifi
Date: Sat, 19 Aug 2023 19:16:18 -0500
Subject: [PATCH] fix prompt continuation for petals

---
 src/lib/providers/petals/models.svelte  | 2 +-
 src/lib/providers/petals/request.svelte | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/lib/providers/petals/models.svelte b/src/lib/providers/petals/models.svelte
index 724a45a..4d36c75 100644
--- a/src/lib/providers/petals/models.svelte
+++ b/src/lib/providers/petals/models.svelte
@@ -45,7 +45,7 @@ const chatModelBase = {
     return prompts.reduce((a, m) => {
       a += countMessageTokens(m, model, chat)
       return a
-    }, 0) + countTokens(model, getStartSequence(chat)) + ((prompts[prompts.length - 1] || {}).role !== 'assistant' ? countTokens(model, getLeadPrompt(chat)) : 0)
+    }, 0) + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat))
   }
 } as ModelDetail
diff --git a/src/lib/providers/petals/request.svelte b/src/lib/providers/petals/request.svelte
index 1f8467f..b24e302 100644
--- a/src/lib/providers/petals/request.svelte
+++ b/src/lib/providers/petals/request.svelte
@@ -62,6 +62,12 @@ export const chatRequest = async (
   const buildMessage = (m: Message): string => {
     return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
   }
+  const lastMessage = rMessages[rMessages.length - 1]
+  let doLead = true
+  if (lastMessage && lastMessage.role === 'assistant') {
+    lastMessage.content = leadPromptSequence + lastMessage.content
+    doLead = false
+  }
   const inputArray = rMessages.reduce((a, m, i) => {
     let c = buildMessage(m)
     let replace = false
@@ -96,7 +102,7 @@ export const chatRequest = async (
     }
     return a
   }, [] as Message[])
-  const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? delimiter + leadPromptSequence : ''
+  const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
   const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(delimiter) + leadPrompt

   let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)