//! headless_lms_chatbot/llm_utils.rs — shared utilities for making LLM requests.

1use crate::{
2    azure_chatbot::{LLMRequest, ToolCallType},
3    chatbot_error::ChatbotResult,
4    prelude::*,
5};
6use core::default::Default;
7use headless_lms_models::{
8    application_task_default_language_models::TaskLMSpec,
9    chatbot_conversation_message_tool_calls::ChatbotConversationMessageToolCall,
10    chatbot_conversation_message_tool_outputs::ChatbotConversationMessageToolOutput,
11    chatbot_conversation_messages::{ChatbotConversationMessage, MessageRole},
12};
13use headless_lms_utils::ApplicationConfiguration;
14use reqwest::Response;
15use reqwest::header::HeaderMap;
16use serde::{Deserialize, Serialize};
17use tracing::{debug, error, instrument, trace, warn};
18
/// API version for Azure OpenAI calls; appended as the `api-version` query
/// parameter to every request URL (see `prepare_azure_endpoint`).
pub const LLM_API_VERSION: &str = "2024-10-21";
21
/// Common message structure used for LLM API requests.
///
/// `fields` is flattened, so on the wire this serializes as a single JSON
/// object containing `role` alongside the kind-specific fields.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessage {
    // Who produced the message (assistant/tool/user/system).
    pub role: MessageRole,
    // Kind-specific payload, flattened into the same JSON object.
    #[serde(flatten)]
    pub fields: APIMessageKind,
}
29
30impl APIMessage {
31    /// Create a ChatbotConversationMessage from an APIMessage to save it into the DB.
32    /// Notice that the insert operation ignores some of the fields, like timestamps.
33    /// `to_chatbot_conversation_message` doesn't set the correct order_number field
34    /// value.
35    pub fn to_chatbot_conversation_message(
36        &self,
37        conversation_id: Uuid,
38        order_number: i32,
39    ) -> ChatbotResult<ChatbotConversationMessage> {
40        let res = match self.fields.clone() {
41            APIMessageKind::Text(msg) => ChatbotConversationMessage {
42                message_role: self.role,
43                conversation_id,
44                order_number,
45                message_is_complete: true,
46                used_tokens: estimate_tokens(&msg.content),
47                message: Some(msg.content),
48                ..Default::default()
49            },
50            APIMessageKind::ToolCall(msg) => {
51                let tool_call_fields = msg
52                    .tool_calls
53                    .iter()
54                    .map(|x| ChatbotConversationMessageToolCall::try_from(x.to_owned()))
55                    .collect::<ChatbotResult<Vec<_>>>()?;
56                let estimated_tokens: i32 = msg
57                    .tool_calls
58                    .iter()
59                    .map(|x| estimate_tokens(&x.function.arguments))
60                    .sum();
61                ChatbotConversationMessage {
62                    message_role: self.role,
63                    conversation_id,
64                    order_number,
65                    message_is_complete: true,
66                    message: None,
67                    tool_call_fields,
68                    used_tokens: estimated_tokens,
69                    ..Default::default()
70                }
71            }
72            APIMessageKind::ToolResponse(msg) => ChatbotConversationMessage {
73                message_role: self.role,
74                conversation_id,
75                order_number,
76                message_is_complete: true,
77                message: None,
78                used_tokens: 0,
79                tool_output: Some(ChatbotConversationMessageToolOutput::from(msg)),
80                ..Default::default()
81            },
82        };
83        Result::Ok(res)
84    }
85}
86
87impl TryFrom<ChatbotConversationMessage> for APIMessage {
88    type Error = ChatbotError;
89    fn try_from(message: ChatbotConversationMessage) -> ChatbotResult<Self> {
90        let res = match message.message_role {
91            MessageRole::Assistant => {
92                if !message.tool_call_fields.is_empty() {
93                    APIMessage {
94                        role: message.message_role,
95                        fields: APIMessageKind::ToolCall(APIMessageToolCall {
96                            tool_calls: message
97                                .tool_call_fields
98                                .iter()
99                                .map(|f| APIToolCall::from(f.clone()))
100                                .collect(),
101                        }),
102                    }
103                } else if let Some(msg) = message.message {
104                    APIMessage {
105                        role: message.message_role,
106                        fields: APIMessageKind::Text(APIMessageText { content: msg }),
107                    }
108                } else {
109                    return Err(ChatbotError::new(
110                        ChatbotErrorType::InvalidMessageShape,
111                        "A 'role: assistant' type ChatbotConversationMessage must have either tool_call_fields or a text message.",
112                        None,
113                    ));
114                }
115            }
116            MessageRole::Tool => {
117                if let Some(tool_output) = message.tool_output {
118                    APIMessage {
119                        role: message.message_role,
120                        fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
121                            tool_call_id: tool_output.tool_call_id,
122                            name: tool_output.tool_name,
123                            content: tool_output.tool_output,
124                        }),
125                    }
126                } else {
127                    return Err(ChatbotError::new(
128                        ChatbotErrorType::InvalidMessageShape,
129                        "A 'role: tool' type ChatbotConversationMessage must have field tool_output.",
130                        None,
131                    ));
132                }
133            }
134            MessageRole::User => APIMessage {
135                role: message.message_role,
136                fields: APIMessageKind::Text(APIMessageText {
137                    content: message.message.unwrap_or_default(),
138                }),
139            },
140            MessageRole::System => {
141                return Err(ChatbotError::new(
142                    ChatbotErrorType::InvalidMessageShape,
143                    "A 'role: system' type ChatbotConversationMessage cannot be saved into the database.",
144                    None,
145                ));
146            }
147        };
148        Result::Ok(res)
149    }
150}
151
/// The kind-specific payload of an APIMessage.
///
/// `#[serde(untagged)]` means the variant is chosen purely by which JSON
/// fields are present, matching the chat-completions wire format.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum APIMessageKind {
    Text(APIMessageText),
    ToolCall(APIMessageToolCall),
    ToolResponse(APIMessageToolResponse),
}
159
/// LLM api message that contains only text
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageText {
    // Plain-text body of the message.
    pub content: String,
}
165
/// LLM api message that contains tool calls. The tool calls were originally made by
/// the LLM, but have been processed and added to the messages in a LLMRequest
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolCall {
    // One or more calls the LLM requested in a single assistant turn.
    pub tool_calls: Vec<APIToolCall>,
}
172
/// LLM api message that contains the output of a single tool call
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolResponse {
    // Id of the tool call this output answers (matches `APIToolCall::id`).
    pub tool_call_id: String,
    // Name of the tool that was invoked.
    pub name: String,
    // The tool's textual output.
    pub content: String,
}
180
181impl From<ChatbotConversationMessageToolOutput> for APIMessageToolResponse {
182    fn from(value: ChatbotConversationMessageToolOutput) -> Self {
183        APIMessageToolResponse {
184            tool_call_id: value.tool_call_id,
185            name: value.tool_name,
186            content: value.tool_output,
187        }
188    }
189}
190
191impl From<APIMessageToolResponse> for ChatbotConversationMessageToolOutput {
192    fn from(value: APIMessageToolResponse) -> Self {
193        ChatbotConversationMessageToolOutput {
194            tool_name: value.name,
195            tool_output: value.content,
196            tool_call_id: value.tool_call_id,
197            ..Default::default()
198        }
199    }
200}
201
/// An LLM tool call that is part of a request to Azure
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIToolCall {
    // The function being called, with its argument string.
    pub function: APITool,
    // Identifier the eventual tool output must echo back as `tool_call_id`.
    pub id: String,
    // Serialized as "type"; always `Function` in this crate's conversions.
    #[serde(rename = "type")]
    pub tool_type: ToolCallType,
}
210
211impl From<ChatbotConversationMessageToolCall> for APIToolCall {
212    fn from(value: ChatbotConversationMessageToolCall) -> Self {
213        APIToolCall {
214            function: APITool {
215                arguments: value.tool_arguments.to_string(),
216                name: value.tool_name,
217            },
218            id: value.tool_call_id,
219            tool_type: ToolCallType::Function,
220        }
221    }
222}
223
224impl TryFrom<APIToolCall> for ChatbotConversationMessageToolCall {
225    type Error = ChatbotError;
226    fn try_from(value: APIToolCall) -> ChatbotResult<Self> {
227        Ok(ChatbotConversationMessageToolCall {
228            tool_name: value.function.name,
229            tool_arguments: serde_json::from_str(&value.function.arguments)?,
230            tool_call_id: value.id,
231            ..Default::default()
232        })
233    }
234}
235
/// A callable function referenced by a tool call: its name plus the call
/// arguments as a string (parsed as JSON when converted back to a DB row).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct APITool {
    pub arguments: String,
    pub name: String,
}
241
/// Simple completion-focused LLM request for Azure OpenAI
/// Note: In Azure OpenAI, the model is specified in the URL, not in the request body
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    // Shared request fields, flattened into the same JSON object.
    #[serde(flatten)]
    pub base: LLMRequest,
    // Whether the API should stream the response or return a single payload.
    pub stream: bool,
}
250
/// Response from LLM for simple (non-streaming) completions
#[derive(Deserialize, Debug)]
pub struct LLMCompletionResponse {
    // One entry per generated completion choice.
    pub choices: Vec<LLMChoice>,
}
256
/// A single completion choice within an LLMCompletionResponse.
#[derive(Deserialize, Debug)]
pub struct LLMChoice {
    pub message: APIMessage,
}
261
262/// Builds common headers for LLM requests
263#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
264pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
265    trace!("Building LLM request headers");
266    let mut headers = HeaderMap::new();
267    headers.insert(
268        "api-key",
269        api_key.parse().map_err(|_e| {
270            error!("Failed to parse API key");
271            anyhow::anyhow!("Invalid API key")
272        })?,
273    );
274    headers.insert(
275        "content-type",
276        "application/json".parse().map_err(|_e| {
277            error!("Failed to parse content-type header");
278            anyhow::anyhow!("Internal error")
279        })?,
280    );
281    trace!("Successfully built headers");
282    Ok(headers)
283}
284
285/// Prepares Azure OpenAI endpoint with API version
286#[instrument(skip(endpoint))]
287pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
288    trace!(
289        "Preparing Azure endpoint with API version {}",
290        LLM_API_VERSION
291    );
292    // Always set the API version so that we actually use the API that the code is written for
293    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
294    trace!("Endpoint prepared: {}", endpoint);
295    endpoint
296}
297
298/// Estimate the number of tokens in a given text.
299#[instrument(skip(text), fields(text_length = text.len()))]
300pub fn estimate_tokens(text: &str) -> i32 {
301    trace!("Estimating tokens for text");
302    let text_length = text.chars().fold(0, |acc, c| {
303        let mut len = c.len_utf8() as i32;
304        if len > 1 {
305            // The longer the character is, the more likely the text around is taking up more tokens
306            len *= 2;
307        }
308        if c.is_ascii_punctuation() {
309            // Punctuation is less common and is thus less likely to be part of a token
310            len *= 2;
311        }
312        acc + len
313    });
314    // A token is roughly 4 characters
315    let estimated_tokens = text_length / 4;
316    trace!("Estimated {} tokens for text", estimated_tokens);
317    estimated_tokens
318}
319
/// Makes a non-streaming request to an LLM.
///
/// Posts `chat_request` (with `stream: false`) to the given Azure OpenAI
/// `endpoint` and returns the parsed completion.
#[instrument(skip(chat_request, endpoint, api_key), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    chat_request: LLMRequest,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing LLM request with {} messages",
        chat_request.messages.len()
    );

    trace!("Base request: {:?}", chat_request);

    // `stream: false` makes this a single blocking completion.
    let request = AzureCompletionRequest {
        base: chat_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    // Status checking and JSON parsing are delegated to the response processor.
    process_llm_response(response).await
}
357
/// Process a non-streaming LLM response.
///
/// On a non-success status the body is consumed to build the error message;
/// otherwise the body is parsed as an `LLMCompletionResponse`.
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LLMCompletionResponse> {
    if !response.status().is_success() {
        // Capture the status before `text()` consumes the response.
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    // Parse the response
    let completion: LLMCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}
385
/// Makes a streaming request to an LLM.
///
/// Resolves the Azure chatbot configuration, posts `chat_request` with
/// `stream: true`, and returns the raw `Response` so the caller can consume
/// the stream. The response body is only read here on error.
///
/// # Errors
/// Fails when the Azure/chatbot configuration is missing, the request cannot
/// be sent, or the API answers with a non-success status.
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    chat_request: LLMRequest,
    model_deployment_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
    // Azure routes by deployment name in the URL path rather than a model
    // field in the request body.
    let api_endpoint = chatbot_config
        .api_endpoint
        .join(&(model_deployment_name.to_owned() + "/chat/completions"))?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        // Consuming the body is fine here; the response is unusable anyway.
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}
452
/// Makes a non-streaming request to an LLM using application configuration.
///
/// Builds the deployment-specific `/chat/completions` URL from the Azure
/// chatbot configuration and delegates to `make_llm_request`.
///
/// # Errors
/// Fails when the Azure/chatbot configuration is missing, the endpoint URL
/// cannot be built, or the underlying request fails.
#[instrument(skip(chat_request, app_config, task_lm), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    chat_request: LLMRequest,
    app_config: &ApplicationConfiguration,
    task_lm: &TaskLMSpec,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    // Azure routes by deployment name in the URL path rather than a model
    // field in the request body.
    let model = task_lm.deployment_name.to_owned();
    let path = model + "/chat/completions";

    let api_endpoint = chatbot_config.api_endpoint.join(&path)?;

    trace!("Making LLM request to endpoint: {}", api_endpoint);
    make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
}
486
487/// Collects all the completion choices to a string. Assumes the completion has only
488/// text message content, no tool calls or responses.
489pub fn parse_text_completion(completion: LLMCompletionResponse) -> ChatbotResult<String> {
490    let res =
491    completion
492        .choices
493        .into_iter()
494        .map(|x| match x.message.fields {
495            APIMessageKind::Text(message) => Ok(message.content),
496            _ =>  Err(ChatbotError::new( ChatbotErrorType::InvalidMessageShape, "It was assumed this LLM response contains only text, but a tool call or tool response was detected.", None)),
497        })
498        .collect::<ChatbotResult<Vec<String>>>()?
499        .join("");
500    if res.is_empty() {
501        return Err(ChatbotError::new(
502            ChatbotErrorType::InvalidMessageShape,
503            "No content returned from LLM",
504            None,
505        ));
506    };
507    Ok(res)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the heuristic's output for a spread of scripts. The "real number"
    // comments record the count an actual tokenizer reportedly gives, to show
    // how far off each estimate is.
    #[test]
    fn test_estimate_tokens() {
        // The real number is 4
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
        // The real number is 9
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        // The real number is 7
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        // The real number is 9
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        // The real number is 52
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        // The real number is 18
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        // The real number is 17
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
}