headless_lms_chatbot/llm_utils.rs

use crate::{azure_chatbot::LLMRequest, prelude::*};
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace, warn};

// API version for Azure OpenAI calls
pub const LLM_API_VERSION: &str = "2024-06-01";

/// Role of a message in a conversation
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum MessageRole {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
}

/// Common message structure used for LLM API requests
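///
/// Serializes to the chat message shape used by the API, e.g.
/// `{"role": "user", "content": "Hello"}`; role names are lowercased via `serde(rename)`.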
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessage {
    pub role: MessageRole,
    pub content: String,
}

/// Simple completion-focused LLM request for Azure OpenAI
/// Note: In Azure OpenAI, the model is specified in the URL, not in the request body
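///
/// Because of `#[serde(flatten)]`, the serialized body consists of the wrapped
/// [`LLMRequest`] fields (such as `messages`) with `stream` added, so the same base
/// request can be sent in both streaming and blocking modes.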
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    #[serde(flatten)]
    pub base: LLMRequest,
    pub stream: bool,
}

/// Response from LLM for simple completions
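///
/// Captures only the part of the chat completion response this crate uses, e.g.
/// `{"choices": [{"message": {"role": "assistant", "content": "..."}}]}`; any other
/// response fields are ignored during deserialization.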
#[derive(Deserialize, Debug)]
pub struct LLMCompletionResponse {
    pub choices: Vec<LLMChoice>,
}

#[derive(Deserialize, Debug)]
pub struct LLMChoice {
    pub message: APIMessage,
}

/// Builds common headers for LLM requests
#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
    trace!("Building LLM request headers");
    let mut headers = HeaderMap::new();
    headers.insert(
        "api-key",
        api_key.parse().map_err(|_e| {
            error!("Failed to parse API key");
            anyhow::anyhow!("Invalid API key")
        })?,
    );
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_e| {
            error!("Failed to parse content-type header");
            anyhow::anyhow!("Internal error")
        })?,
    );
    trace!("Successfully built headers");
    Ok(headers)
}

/// Prepares Azure OpenAI endpoint with API version
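///
/// Any existing query string is replaced, so a deployment URL such as
/// `https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions`
/// (an illustrative address) ends up with the query `api-version=2024-06-01`,
/// i.e. the current [`LLM_API_VERSION`].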
#[instrument(skip(endpoint))]
pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
    trace!(
        "Preparing Azure endpoint with API version {}",
        LLM_API_VERSION
    );
    // Always set the API version so that we actually use the API that the code is written for
    endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
    trace!("Endpoint prepared: {}", endpoint);
    endpoint
}

/// Estimate the number of tokens in a given text.
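///
/// This is a rough heuristic rather than a real tokenizer: multi-byte characters and
/// ASCII punctuation are weighted more heavily, and the weighted length is divided by
/// four. For example, `"Hello, world!"` has 11 regular characters plus `,` and `!`
/// weighted 2 each, giving a weighted length of 15 and an estimate of 15 / 4 = 3 tokens.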
#[instrument(skip(text), fields(text_length = text.len()))]
pub fn estimate_tokens(text: &str) -> i32 {
    trace!("Estimating tokens for text");
    let text_length = text.chars().fold(0, |acc, c| {
        let mut len = c.len_utf8() as i32;
        if len > 1 {
            // The longer the character, the more likely the surrounding text is using up more tokens
            len *= 2;
        }
        if c.is_ascii_punctuation() {
            // Punctuation is less common and thus less likely to be merged into a longer token
            len *= 2;
        }
        acc + len
    });
    // A token is roughly 4 characters
    let estimated_tokens = text_length / 4;
    trace!("Estimated {} tokens for text", estimated_tokens);
    estimated_tokens
}

/// Makes a non-streaming request to an LLM
#[instrument(skip(chat_request, endpoint, api_key), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens,
    endpoint = %endpoint
))]
async fn make_llm_request(
    chat_request: LLMRequest,
    endpoint: &url::Url,
    api_key: &str,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing LLM request with {} messages",
        chat_request.messages.len()
    );

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: false,
    };

    let headers = build_llm_headers(api_key)?;
    debug!("Sending request to LLM endpoint: {}", endpoint);

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    trace!("Received response from LLM");
    process_llm_response(response).await
}

/// Process a non-streaming LLM response
#[instrument(skip(response), fields(status = %response.status()))]
async fn process_llm_response(response: Response) -> anyhow::Result<LLMCompletionResponse> {
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    trace!("Processing successful LLM response");
    // Parse the response
    let completion: LLMCompletionResponse = response.json().await?;
    debug!(
        "Successfully processed LLM response with {} choices",
        completion.choices.len()
    );
    Ok(completion)
}

/// Makes a streaming request to an LLM
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_streaming_llm_request(
    chat_request: LLMRequest,
    model_deployment_name: &str,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<Response> {
    debug!(
        "Preparing streaming LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    trace!("Base request: {:?}", chat_request);

    let request = AzureCompletionRequest {
        base: chat_request,
        stream: true,
    };

    let headers = build_llm_headers(&chatbot_config.api_key)?;
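    // Note: `Url::join` resolves the deployment path relative to the configured
    // endpoint, so `api_endpoint` is expected to end with a trailing slash; otherwise
    // the final path segment would be replaced instead of extended.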
    let api_endpoint = chatbot_config
        .api_endpoint
        .join(&(model_deployment_name.to_owned() + "/chat/completions"))?;
    debug!(
        "Sending streaming request to LLM endpoint: {}",
        api_endpoint
    );

    let response = REQWEST_CLIENT
        .post(prepare_azure_endpoint(api_endpoint.clone()))
        .headers(headers)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await?;
        error!(
            status = %status,
            error = %error_text,
            "Error calling streaming LLM API"
        );
        return Err(anyhow::anyhow!(
            "Error calling LLM API: Status: {}. Error: {}",
            status,
            error_text
        ));
    }

    debug!("Successfully initiated streaming response");
    Ok(response)
}

/// Makes a non-streaming request to an LLM using application configuration
#[instrument(skip(chat_request, app_config), fields(
    num_messages = chat_request.messages.len(),
    temperature,
    max_tokens
))]
pub async fn make_blocking_llm_request(
    chat_request: LLMRequest,
    app_config: &ApplicationConfiguration,
) -> anyhow::Result<LLMCompletionResponse> {
    debug!(
        "Preparing blocking LLM request with {} messages",
        chat_request.messages.len()
    );
    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
        error!("Azure configuration missing");
        anyhow::anyhow!("Azure configuration is missing from the application configuration")
    })?;

    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
        error!("Chatbot configuration missing");
        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
    })?;

    let api_endpoint = chatbot_config
        .api_endpoint
        .join("gpt-4o/chat/completions")?;

    trace!("Making LLM request to endpoint: {}", api_endpoint);
    make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // The real number is 4
        assert_eq!(estimate_tokens("Hello, world!"), 3);
        assert_eq!(estimate_tokens(""), 0);
        // The real number is 9
        assert_eq!(
            estimate_tokens("This is a longer sentence with several words."),
            11
        );
        // The real number is 7
        assert_eq!(estimate_tokens("Hyvää päivää!"), 7);
        // The real number is 9
        assert_eq!(estimate_tokens("トークンは楽しい"), 12);
        // The real number is 52
        assert_eq!(
            estimate_tokens("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇"),
            48
        );
        // The real number is 18
        assert_eq!(estimate_tokens("ฉันใช้โทเค็นทุกวัน"), 27);
        // The real number is 17
        assert_eq!(estimate_tokens("Жетони роблять мене щасливим"), 25);
    }
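
    // Additional sketches of expected behaviour for the helpers above; the endpoint
    // URL and API key are made-up values for illustration.
    #[test]
    fn test_prepare_azure_endpoint_sets_api_version() {
        let endpoint = url::Url::parse(
            "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions",
        )
        .expect("valid url");
        let prepared = prepare_azure_endpoint(endpoint);
        assert_eq!(
            prepared.query(),
            Some(format!("api-version={}", LLM_API_VERSION).as_str())
        );
    }

    #[test]
    fn test_build_llm_headers_sets_expected_headers() {
        let headers = build_llm_headers("test-api-key").expect("headers should build");
        assert_eq!(headers.get("api-key").unwrap(), "test-api-key");
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
    }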
}