headless_lms_chatbot/llm_utils.rs

use crate::prelude::*;
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace};

/// API version for Azure OpenAI calls
pub const LLM_API_VERSION: &str = "2024-06-01";

/// Role of a message in a conversation
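///
/// Serializes to the lowercase role names expected by the chat completions API.
///
/// # Example (illustrative sketch; assumes `serde_json` is available)
/// ```ignore
/// assert_eq!(serde_json::to_string(&MessageRole::Assistant).unwrap(), "\"assistant\"");
/// ```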
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum MessageRole {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
}

/// Common message structure used for LLM requests
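///
/// # Example (illustrative sketch)
/// ```ignore
/// let msg = Message {
///     role: MessageRole::User,
///     content: "Summarize this page for me.".to_string(),
/// };
/// ```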
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
}

/// Base LLM request structure (common fields)
#[derive(Serialize, Deserialize, Debug)]
pub struct BaseLlmRequest {
    pub messages: Vec<Message>,
    pub temperature: f32,
    pub max_tokens: Option<i32>,
}

/// Simple completion-focused LLM request for Azure OpenAI
/// Note: In Azure OpenAI, the model is specified in the URL, not in the request body
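///
/// Because `base` is marked `#[serde(flatten)]`, its fields serialize at the
/// top level of the JSON body, next to `stream`, rather than nested under `base`.
///
/// # Example (illustrative sketch; assumes `serde_json` is available)
/// ```ignore
/// let request = AzureCompletionRequest {
///     base: BaseLlmRequest { messages: vec![], temperature: 0.7, max_tokens: None },
///     stream: false,
/// };
/// let json = serde_json::to_value(&request).unwrap();
/// assert!(json.get("messages").is_some()); // flattened, not under a "base" key
/// ```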
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    #[serde(flatten)]
    pub base: BaseLlmRequest,
    pub stream: bool,
}

/// Response from LLM for simple completions
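///
/// # Example (illustrative sketch; `completion` stands in for a parsed response)
/// ```ignore
/// let reply: Option<&str> = completion
///     .choices
///     .first()
///     .map(|choice| choice.message.content.as_str());
/// ```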
#[derive(Deserialize, Debug)]
pub struct LlmCompletionResponse {
    pub choices: Vec<LlmChoice>,
}

#[derive(Deserialize, Debug)]
pub struct LlmChoice {
    pub message: Message,
}

/// Builds common headers for LLM requests
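///
/// # Example (illustrative sketch; the key is a placeholder)
/// ```ignore
/// let headers = build_llm_headers("not-a-real-key")?;
/// assert!(headers.contains_key("api-key"));
/// assert_eq!(headers["content-type"], "application/json");
/// ```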
#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
    trace!("Building LLM request headers");
    let mut headers = HeaderMap::new();
    headers.insert(
        "api-key",
        api_key.parse().map_err(|_| {
            error!("Failed to parse API key");
            anyhow::anyhow!("Invalid API key")
        })?,
    );
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_| {
            error!("Failed to parse content-type header");
            anyhow::anyhow!("Internal error")
        })?,
    );
    trace!("Successfully built headers");
    Ok(headers)
}

/// Prepares Azure OpenAI endpoint with API version
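///
/// # Example (illustrative sketch; the endpoint URL is a placeholder)
/// ```ignore
/// let endpoint = url::Url::parse(
///     "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions",
/// )?;
/// let prepared = prepare_azure_endpoint(endpoint);
/// assert_eq!(prepared.query(), Some("api-version=2024-06-01"));
/// ```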
#[instrument(skip(endpoint))]
pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
    trace!(
        "Preparing Azure endpoint with API version {}",
        LLM_API_VERSION
    );
    // Note: this replaces any query string already present on the endpoint.
    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
    trace!("Endpoint prepared: {}", endpoint);
    endpoint
}

/// Estimate the number of tokens in a given text.
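///
/// This is a cheap heuristic (roughly four characters per token, weighted up for
/// multi-byte characters and punctuation), not a real tokenizer; the tests below
/// show how far it drifts from actual token counts.
///
/// # Example
/// ```ignore
/// assert_eq!(estimate_tokens("Hello, world!"), 3); // a real tokenizer counts 4
/// ```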
#[instrument(skip(text), fields(text_length = text.len()))]
pub fn estimate_tokens(text: &str) -> i32 {
    trace!("Estimating tokens for text");
    let text_length = text.chars().fold(0, |acc, c| {
        let mut len = c.len_utf8() as i32;
        if len > 1 {
            // Multi-byte characters tend to appear in text that splits into more tokens
            len *= 2;
        }
        if c.is_ascii_punctuation() {
            // Punctuation is less common and thus less likely to be merged into a surrounding token
            len *= 2;
        }
        acc + len
    });
    // A token is roughly 4 characters
    let estimated_tokens = text_length / 4;
    trace!("Estimated {} tokens for text", estimated_tokens);
    estimated_tokens
}

/// Makes a non-streaming request to an LLM
#[instrument(skip(messages, endpoint, api_key), fields(
    num_messages = messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: Option<i32>,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LlmCompletionResponse> {
    debug!("Preparing LLM request with {} messages", messages.len());
    let base_request = BaseLlmRequest {
        messages,
        temperature,
        max_tokens,
    };

    trace!("Base request prepared: {:?}", base_request);

    let request = AzureCompletionRequest {
        base: base_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    process_llm_response(response).await
}

/// Process a non-streaming LLM response
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LlmCompletionResponse> {
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    // Parse the response
    let completion: LlmCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}

/// Makes a streaming request to an LLM
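///
/// Returns the raw [`Response`] so the caller can consume the server-sent
/// event stream itself.
///
/// # Example (illustrative sketch; assumes a valid `ApplicationConfiguration`)
/// ```ignore
/// let mut response = make_streaming_llm_request(messages, 0.7, None, &app_config).await?;
/// while let Some(chunk) = response.chunk().await? {
///     // Each chunk carries one or more `data: {...}` SSE lines from Azure OpenAI.
///     print!("{}", String::from_utf8_lossy(&chunk));
/// }
/// ```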
#[instrument(skip(messages, app_config), fields(
    num_messages = messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: Option<i32>,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    let base_request = BaseLlmRequest {
        messages,
        temperature,
        max_tokens,
    };

    trace!("Base request prepared: {:?}", base_request);

    let request = AzureCompletionRequest {
        base: base_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        chatbot_config.api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(chatbot_config.api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling streaming LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}

/// Makes a non-streaming request to an LLM using application configuration
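///
/// # Example (illustrative sketch; assumes a valid `ApplicationConfiguration`)
/// ```ignore
/// let messages = vec![Message {
///     role: MessageRole::User,
///     content: "Hello!".to_string(),
/// }];
/// let completion = make_blocking_llm_request(messages, 0.7, Some(512), &app_config).await?;
/// if let Some(choice) = completion.choices.first() {
///     println!("{}", choice.message.content);
/// }
/// ```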
#[instrument(skip(messages, app_config), fields(
    num_messages = messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: Option<i32>,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<LlmCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!(
        "Making LLM request to endpoint: {}",
        chatbot_config.api_endpoint
    );
    make_llm_request(
        messages,
        temperature,
        max_tokens,
        &chatbot_config.api_endpoint,
        &chatbot_config.api_key,
    )
    .await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // A real tokenizer counts 4
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
        // A real tokenizer counts 9
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        // A real tokenizer counts 7
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        // A real tokenizer counts 9
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        // A real tokenizer counts 52
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        // A real tokenizer counts 18
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        // A real tokenizer counts 17
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
}