headless_lms_chatbot/llm_utils.rs

use crate::{
    azure_chatbot::{LLMRequest, ToolCallType},
    chatbot_error::ChatbotResult,
    prelude::*,
};
use core::default::Default;
use headless_lms_models::{
    chatbot_conversation_message_tool_calls::ChatbotConversationMessageToolCall,
    chatbot_conversation_message_tool_outputs::ChatbotConversationMessageToolOutput,
    chatbot_conversation_messages::{ChatbotConversationMessage, MessageRole},
};
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace};

// API version for Azure OpenAI calls
pub const LLM_API_VERSION: &str = "2024-06-01";

/// Common message structure used for LLM API requests
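///
/// Because `fields` is `#[serde(flatten)]`ed and `APIMessageKind` is untagged, the
/// serialized JSON is a single flat object. For example, a text message looks
/// roughly like this (the exact `role` string depends on how `MessageRole` is
/// serialized):
///
/// ```json
/// { "role": "user", "content": "Hello" }
/// ```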
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessage {
    pub role: MessageRole,
    #[serde(flatten)]
    pub fields: APIMessageKind,
}

impl APIMessage {
    /// Creates a ChatbotConversationMessage from this APIMessage so that it can be
    /// saved into the DB. Note that the insert operation ignores some of the fields,
    /// such as timestamps, and that this method does not compute the correct
    /// `order_number`; the caller is responsible for passing the right value.
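    ///
    /// A minimal usage sketch (not compiled here; `conversation_id` is assumed to
    /// come from the surrounding code):
    ///
    /// ```ignore
    /// let api_msg = APIMessage {
    ///     role: MessageRole::User,
    ///     fields: APIMessageKind::Text(APIMessageText { content: "Hi!".to_string() }),
    /// };
    /// let db_msg = api_msg.to_chatbot_conversation_message(conversation_id, 0)?;
    /// assert_eq!(db_msg.order_number, 0);
    /// ```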
    pub fn to_chatbot_conversation_message(
        &self,
        conversation_id: Uuid,
        order_number: i32,
    ) -> ChatbotResult<ChatbotConversationMessage> {
        let res = match self.fields.clone() {
            APIMessageKind::Text(msg) => ChatbotConversationMessage {
                message_role: self.role,
                conversation_id,
                order_number,
                message_is_complete: true,
                used_tokens: estimate_tokens(&msg.content),
                message: Some(msg.content),
                ..Default::default()
            },
            APIMessageKind::ToolCall(msg) => {
                let tool_call_fields = msg
                    .tool_calls
                    .iter()
                    .map(|x| ChatbotConversationMessageToolCall::try_from(x.to_owned()))
                    .collect::<ChatbotResult<Vec<_>>>()?;
                let estimated_tokens: i32 = msg
                    .tool_calls
                    .iter()
                    .map(|x| estimate_tokens(&x.function.arguments))
                    .sum();
                ChatbotConversationMessage {
                    message_role: self.role,
                    conversation_id,
                    order_number,
                    message_is_complete: true,
                    message: None,
                    tool_call_fields,
                    used_tokens: estimated_tokens,
                    ..Default::default()
                }
            }
            APIMessageKind::ToolResponse(msg) => ChatbotConversationMessage {
                message_role: self.role,
                conversation_id,
                order_number,
                message_is_complete: true,
                message: None,
                used_tokens: 0,
                tool_output: Some(ChatbotConversationMessageToolOutput::from(msg)),
                ..Default::default()
            },
        };
        Ok(res)
    }
}

impl TryFrom<ChatbotConversationMessage> for APIMessage {
    type Error = ChatbotError;
    fn try_from(message: ChatbotConversationMessage) -> ChatbotResult<Self> {
        let res = match message.message_role {
            MessageRole::Assistant => {
                if !message.tool_call_fields.is_empty() {
                    APIMessage {
                        role: message.message_role,
                        fields: APIMessageKind::ToolCall(APIMessageToolCall {
                            tool_calls: message
                                .tool_call_fields
                                .iter()
                                .map(|f| APIToolCall::from(f.clone()))
                                .collect(),
                        }),
                    }
                } else if let Some(msg) = message.message {
                    APIMessage {
                        role: message.message_role,
                        fields: APIMessageKind::Text(APIMessageText { content: msg }),
                    }
                } else {
                    return Err(ChatbotError::new(
                        ChatbotErrorType::InvalidMessageShape,
                        "A 'role: assistant' type ChatbotConversationMessage must have either tool_call_fields or a text message.",
                        None,
                    ));
                }
            }
            MessageRole::Tool => {
                if let Some(tool_output) = message.tool_output {
                    APIMessage {
                        role: message.message_role,
                        fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
                            tool_call_id: tool_output.tool_call_id,
                            name: tool_output.tool_name,
                            content: tool_output.tool_output,
                        }),
                    }
                } else {
                    return Err(ChatbotError::new(
                        ChatbotErrorType::InvalidMessageShape,
                        "A 'role: tool' type ChatbotConversationMessage must have field tool_output.",
                        None,
                    ));
                }
            }
            MessageRole::User => APIMessage {
                role: message.message_role,
                fields: APIMessageKind::Text(APIMessageText {
                    content: message.message.unwrap_or_default(),
                }),
            },
            MessageRole::System => {
                return Err(ChatbotError::new(
                    ChatbotErrorType::InvalidMessageShape,
                    "A 'role: system' type ChatbotConversationMessage cannot be saved into the database.",
                    None,
                ));
            }
        };
        Ok(res)
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum APIMessageKind {
    Text(APIMessageText),
    ToolCall(APIMessageToolCall),
    ToolResponse(APIMessageToolResponse),
}

/// LLM API message that contains only text
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageText {
    pub content: String,
}

/// LLM API message that contains tool calls. The tool calls were originally made by
/// the LLM, but have been processed and added to the messages in an LLMRequest.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolCall {
    pub tool_calls: Vec<APIToolCall>,
}

/// LLM API message that contains the outputs of tool calls
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolResponse {
    pub tool_call_id: String,
    pub name: String,
    pub content: String,
}

impl From<ChatbotConversationMessageToolOutput> for APIMessageToolResponse {
    fn from(value: ChatbotConversationMessageToolOutput) -> Self {
        APIMessageToolResponse {
            tool_call_id: value.tool_call_id,
            name: value.tool_name,
            content: value.tool_output,
        }
    }
}

impl From<APIMessageToolResponse> for ChatbotConversationMessageToolOutput {
    fn from(value: APIMessageToolResponse) -> Self {
        ChatbotConversationMessageToolOutput {
            tool_name: value.name,
            tool_output: value.content,
            tool_call_id: value.tool_call_id,
            ..Default::default()
        }
    }
}

/// An LLM tool call that is part of a request to Azure
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIToolCall {
    pub function: APITool,
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: ToolCallType,
}

impl From<ChatbotConversationMessageToolCall> for APIToolCall {
    fn from(value: ChatbotConversationMessageToolCall) -> Self {
        APIToolCall {
            function: APITool {
                arguments: value.tool_arguments.to_string(),
                name: value.tool_name,
            },
            id: value.tool_call_id,
            tool_type: ToolCallType::Function,
        }
    }
}

impl TryFrom<APIToolCall> for ChatbotConversationMessageToolCall {
    type Error = ChatbotError;
    fn try_from(value: APIToolCall) -> ChatbotResult<Self> {
        Ok(ChatbotConversationMessageToolCall {
            tool_name: value.function.name,
            tool_arguments: serde_json::from_str(&value.function.arguments)?,
            tool_call_id: value.id,
            ..Default::default()
        })
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct APITool {
    pub arguments: String,
    pub name: String,
}

/// Simple completion-focused LLM request for Azure OpenAI
/// Note: In Azure OpenAI, the model is specified in the URL, not in the request body
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    #[serde(flatten)]
    pub base: LLMRequest,
    pub stream: bool,
}

/// Response from LLM for simple completions
#[derive(Deserialize, Debug)]
pub struct LLMCompletionResponse {
    pub choices: Vec<LLMChoice>,
}

#[derive(Deserialize, Debug)]
pub struct LLMChoice {
    pub message: APIMessage,
}

/// Builds common headers for LLM requests
#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
    trace!("Building LLM request headers");
    let mut headers = HeaderMap::new();
    headers.insert(
        "api-key",
        api_key.parse().map_err(|_e| {
            error!("Failed to parse API key");
            anyhow::anyhow!("Invalid API key")
        })?,
    );
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_e| {
            error!("Failed to parse content-type header");
            anyhow::anyhow!("Internal error")
        })?,
    );
    trace!("Successfully built headers");
    Ok(headers)
}

/// Prepares Azure OpenAI endpoint with API version
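///
/// For example, given a deployment URL (the host and path here are illustrative),
/// any existing query string is replaced with the pinned API version:
///
/// ```ignore
/// let url = url::Url::parse("https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions")?;
/// assert_eq!(prepare_azure_endpoint(url).query(), Some("api-version=2024-06-01"));
/// ```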
#[instrument(skip(endpoint))]
pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
    trace!(
        "Preparing Azure endpoint with API version {}",
        LLM_API_VERSION
    );
    // Always set the API version so that we actually use the API that the code is written for
    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
    trace!("Endpoint prepared: {}", endpoint);
    endpoint
}

/// Estimates the number of tokens in a given text using a rough character-based heuristic.
#[instrument(skip(text), fields(text_length = text.len()))]
pub fn estimate_tokens(text: &str) -> i32 {
    trace!("Estimating tokens for text");
    let text_length = text.chars().fold(0, |acc, c| {
        let mut len = c.len_utf8() as i32;
        if len > 1 {
            // The longer a character's encoding is, the more likely the surrounding text is to take up extra tokens
            len *= 2;
        }
        if c.is_ascii_punctuation() {
            // Punctuation is less common and is thus less likely to be merged into a longer token
            len *= 2;
        }
        acc + len
    });
    // A token is roughly 4 characters
    let estimated_tokens = text_length / 4;
    trace!("Estimated {} tokens for text", estimated_tokens);
    estimated_tokens
}

/// Makes a non-streaming request to an LLM
#[instrument(skip(chat_request, endpoint, api_key), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    chat_request: LLMRequest,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing LLM request with {} messages",
        chat_request.messages.len()
    );

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    process_llm_response(response).await
}

/// Process a non-streaming LLM response
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LLMCompletionResponse> {
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    // Parse the response
    let completion: LLMCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}

/// Makes a streaming request to an LLM
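///
/// A minimal call-site sketch (not compiled here; the request and configuration are
/// assumed to come from the surrounding code). The returned `Response` is the raw
/// streaming body from Azure, which the caller consumes incrementally:
///
/// ```ignore
/// let response = make_streaming_llm_request(chat_request, "gpt-4o", &app_config).await?;
/// let mut chunks = response.bytes_stream();
/// ```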
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    chat_request: LLMRequest,
    model_deployment_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
    let api_endpoint = chatbot_config
        .api_endpoint
        .join(&(model_deployment_name.to_owned() + "/chat/completions"))?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}

/// Makes a non-streaming request to an LLM using application configuration
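///
/// A minimal call-site sketch (not compiled here; the request and configuration are
/// assumed to come from the surrounding code):
///
/// ```ignore
/// let completion = make_blocking_llm_request(chat_request, &app_config).await?;
/// if let Some(choice) = completion.choices.first() {
///     // choice.message is the APIMessage parsed from the LLM's reply.
/// }
/// ```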
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    chat_request: LLMRequest,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    let api_endpoint = chatbot_config.api_endpoint.join("gpt-4o/chat/completions")?;

    trace!("Making LLM request to endpoint: {}", api_endpoint);
    make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // The real number is 4
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
        // The real number is 9
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        // The real number is 7
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        // The real number is 9
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        // The real number is 52
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        // The real number is 18
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        // The real number is 17
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
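
    // A sanity check for prepare_azure_endpoint: any existing query string should be
    // replaced with the pinned API version. The URL is illustrative.
    #[test]
    fn test_prepare_azure_endpoint_sets_api_version() {
        let endpoint = url::Url::parse(
            "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=old",
        )
        .expect("the test URL should be valid");
        let prepared = prepare_azure_endpoint(endpoint);
        let expected = format!("api-version={}", LLM_API_VERSION);
        assert_eq!(prepared.query(), Some(expected.as_str()));
    }

    // A sanity check for build_llm_headers with a dummy key.
    #[test]
    fn test_build_llm_headers_sets_expected_headers() {
        let headers = build_llm_headers("test-api-key").expect("headers should build");
        assert_eq!(headers.get("api-key").unwrap(), "test-api-key");
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
    }

    // Checks the flattened JSON shape produced by APIMessage (see the struct docs).
    // Only the flattened `content` field is asserted, since the exact `role` string
    // depends on MessageRole's serde representation.
    #[test]
    fn test_api_message_serializes_flat() {
        let msg = APIMessage {
            role: MessageRole::User,
            fields: APIMessageKind::Text(APIMessageText {
                content: "Hello".to_string(),
            }),
        };
        let json = serde_json::to_value(&msg).expect("the message should serialize");
        assert_eq!(json["content"], "Hello");
        assert!(json.get("role").is_some());
    }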
}