Skip to main content

headless_lms_models/
chatbot_conversation_messages.rs

1use std::fmt;
2use utoipa::ToSchema;
3
4use crate::{
5    chatbot_conversation_message_tool_calls::{self, ChatbotConversationMessageToolCall},
6    chatbot_conversation_message_tool_outputs::{self, ChatbotConversationMessageToolOutput},
7    prelude::*,
8};
9
// Role of the author of a chatbot conversation message.
//
// Mapped to the Postgres enum type `message_role` (sqlx) and serialized by serde; both use the
// snake_case variant names: "assistant", "user", "tool", "system".
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy, Type, ToSchema)]
#[sqlx(type_name = "message_role", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    // The insert functions in this module require message text or tool calls for this role.
    Assistant,
    User,
    // The insert functions in this module require a tool output for this role.
    Tool,
    // Rejected by the insert functions in this module ("Can't save system message to database").
    System,
}
19
20impl fmt::Display for MessageRole {
21    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22        write!(f, "{:?}", self)
23    }
24}
25
/// A raw row of the `chatbot_conversation_messages` table, as returned by the queries in this
/// module.
///
/// Unlike [`ChatbotConversationMessage`], this row has two differences:
/// - the tool output is only referenced by id;
/// - tool calls are not included.
///
/// [`message_row_to_message`] loads both and builds the full message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ChatbotConversationMessageRow {
    pub id: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Set when the message has been soft-deleted (see [`delete`]).
    pub deleted_at: Option<DateTime<Utc>>,
    pub conversation_id: Uuid,
    /// Text content. May be `None`, e.g. for an assistant message that only carries tool calls.
    pub message: Option<String>,
    pub message_role: MessageRole,
    pub message_is_complete: bool,
    pub used_tokens: i32,
    /// Position of the message in its conversation; [`get_by_conversation_id`] sorts by it.
    pub order_number: i32,
    /// Id of the linked tool output, if any. [`insert`] sets it for `Tool` messages.
    pub tool_output_id: Option<Uuid>,
}
40
// A chatbot conversation message together with its tool output and tool calls.
//
// Built from a `ChatbotConversationMessageRow` with `ChatbotConversationMessage::from_row`, or
// with `message_row_to_message`, which also loads the related tool data from the database.
#[derive(Clone, PartialEq, Deserialize, Serialize, Debug, ToSchema)]
pub struct ChatbotConversationMessage {
    pub id: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub deleted_at: Option<DateTime<Utc>>,
    pub conversation_id: Uuid,
    pub message: Option<String>,
    pub message_role: MessageRole,
    pub message_is_complete: bool,
    pub used_tokens: i32,
    pub order_number: i32,
    // Tool output of the message; the insert functions require it for `Tool` messages.
    pub tool_output: Option<ChatbotConversationMessageToolOutput>,
    // Tool calls of the message; the insert functions persist them only for `Assistant` messages.
    pub tool_call_fields: Vec<ChatbotConversationMessageToolCall>,
}
57
58impl Default for ChatbotConversationMessage {
59    fn default() -> Self {
60        Self {
61            id: Uuid::nil(),
62            created_at: Default::default(),
63            updated_at: Default::default(),
64            deleted_at: None,
65            conversation_id: Uuid::nil(),
66            message: Default::default(),
67            message_role: MessageRole::System,
68            message_is_complete: false,
69            used_tokens: Default::default(),
70            order_number: Default::default(),
71            tool_output: None,
72            tool_call_fields: Default::default(),
73        }
74    }
75}
76
77impl ChatbotConversationMessage {
78    pub fn from_row(
79        r: ChatbotConversationMessageRow,
80        o: Option<ChatbotConversationMessageToolOutput>,
81        c: Vec<ChatbotConversationMessageToolCall>,
82    ) -> Self {
83        ChatbotConversationMessage {
84            id: r.id,
85            created_at: r.created_at,
86            updated_at: r.updated_at,
87            deleted_at: r.deleted_at,
88            conversation_id: r.conversation_id,
89            message: r.message,
90            message_role: r.message_role,
91            message_is_complete: r.message_is_complete,
92            used_tokens: r.used_tokens,
93            order_number: r.order_number,
94            tool_output: o,
95            tool_call_fields: c,
96        }
97    }
98}
99
100pub async fn insert(
101    conn: &mut PgConnection,
102    input: ChatbotConversationMessage,
103) -> ModelResult<ChatbotConversationMessage> {
104    let mut tx = conn.begin().await?;
105    let msg = sqlx::query_as!(
106        ChatbotConversationMessageRow,
107        r#"
108INSERT INTO chatbot_conversation_messages (
109    conversation_id,
110    message,
111    message_role,
112    message_is_complete,
113    used_tokens,
114    order_number,
115    tool_output_id
116)
117VALUES ($1, $2, $3, $4, $5, $6, $7)
118RETURNING
119    id,
120    created_at,
121    updated_at,
122    deleted_at,
123    conversation_id,
124    message,
125    message_role as "message_role: MessageRole",
126    message_is_complete,
127    used_tokens,
128    order_number,
129    tool_output_id
130        "#,
131        input.conversation_id,
132        input.message,
133        input.message_role as MessageRole,
134        input.message_is_complete,
135        input.used_tokens,
136        input.order_number,
137        None::<Uuid>,
138    )
139    .fetch_one(&mut *tx)
140    .await?;
141
142    let (tool_output_id, tool_output) = match msg.message_role {
143        MessageRole::Assistant => {
144            if msg.message.is_some() {
145                (None, None)
146            } else if !input.tool_call_fields.is_empty() {
147                chatbot_conversation_message_tool_calls::insert_batch(
148                    &mut tx,
149                    input.tool_call_fields.to_owned(),
150                    msg.id,
151                )
152                .await?;
153                (None, None)
154            } else {
155                return Err(model_err!(
156                    InvalidRequest,
157                    "A chatbot conversation message with role 'assistant' has to have either a message or tool calls"
158                ));
159            }
160        }
161        MessageRole::Tool => {
162            if let Some(output) = input.tool_output {
163                let o_res =
164                    chatbot_conversation_message_tool_outputs::insert(&mut tx, output, msg.id)
165                        .await?;
166                (Some(o_res.id), Some(o_res))
167            } else {
168                return Err(model_err!(
169                    InvalidRequest,
170                    "A chatbot conversation message with role 'tool' must have tool output"
171                ));
172            }
173        }
174        MessageRole::User => (None, None),
175        MessageRole::System => {
176            return Err(model_err!(
177                InvalidRequest,
178                "Can't save system message to database"
179            ));
180        }
181    };
182
183    // Update the message to contain the tool_output_id if it was created
184    if tool_output_id.is_some() {
185        sqlx::query_as!(
186            ChatbotConversationMessageRow,
187            r#"
188UPDATE chatbot_conversation_messages
189SET tool_output_id = $1
190WHERE id = $2
191            "#,
192            tool_output_id,
193            msg.id,
194        )
195        .execute(&mut *tx)
196        .await?;
197    }
198
199    let res = ChatbotConversationMessage::from_row(msg, tool_output, input.tool_call_fields);
200    tx.commit().await?;
201    Ok(res)
202}
203
204pub async fn insert_for_conversation_user_and_configuration(
205    conn: &mut PgConnection,
206    input: ChatbotConversationMessage,
207    user_id: Uuid,
208    chatbot_configuration_id: Uuid,
209) -> ModelResult<ChatbotConversationMessage> {
210    let mut tx = conn.begin().await?;
211
212    sqlx::query!(
213        r#"
214SELECT id
215FROM chatbot_conversations
216WHERE id = $1
217  AND user_id = $2
218  AND chatbot_configuration_id = $3
219  AND deleted_at IS NULL
220        "#,
221        input.conversation_id,
222        user_id,
223        chatbot_configuration_id
224    )
225    .fetch_one(&mut *tx)
226    .await?;
227
228    let msg = sqlx::query_as!(
229        ChatbotConversationMessageRow,
230        r#"
231INSERT INTO chatbot_conversation_messages (
232    conversation_id,
233    message,
234    message_role,
235    message_is_complete,
236    used_tokens,
237    order_number,
238    tool_output_id
239)
240VALUES ($1, $2, $3, $4, $5, $6, $7)
241RETURNING
242    id,
243    created_at,
244    updated_at,
245    deleted_at,
246    conversation_id,
247    message,
248    message_role as "message_role: MessageRole",
249    message_is_complete,
250    used_tokens,
251    order_number,
252    tool_output_id
253        "#,
254        input.conversation_id,
255        input.message,
256        input.message_role as MessageRole,
257        input.message_is_complete,
258        input.used_tokens,
259        input.order_number,
260        None::<Uuid>,
261    )
262    .fetch_one(&mut *tx)
263    .await?;
264
265    let (tool_output_id, tool_output) = match msg.message_role {
266        MessageRole::Assistant => {
267            if msg.message.is_some() {
268                (None, None)
269            } else if !input.tool_call_fields.is_empty() {
270                chatbot_conversation_message_tool_calls::insert_batch(
271                    &mut tx,
272                    input.tool_call_fields.to_owned(),
273                    msg.id,
274                )
275                .await?;
276                (None, None)
277            } else {
278                return ModelResult::Err(ModelError::new(
279                    ModelErrorType::InvalidRequest,
280                    "A chatbot conversation message with role 'assistant' has to have either a message or tool calls",
281                    None,
282                ));
283            }
284        }
285        MessageRole::Tool => {
286            if let Some(output) = input.tool_output {
287                let o_res =
288                    chatbot_conversation_message_tool_outputs::insert(&mut tx, output, msg.id)
289                        .await?;
290                (Some(o_res.id), Some(o_res))
291            } else {
292                return ModelResult::Err(ModelError::new(
293                    ModelErrorType::InvalidRequest,
294                    "A chatbot conversation message with role 'tool' must have tool output",
295                    None,
296                ));
297            }
298        }
299        MessageRole::User => (None, None),
300        MessageRole::System => {
301            return ModelResult::Err(ModelError::new(
302                ModelErrorType::InvalidRequest,
303                "Can't save system message to database",
304                None,
305            ));
306        }
307    };
308
309    if tool_output_id.is_some() {
310        sqlx::query_as!(
311            ChatbotConversationMessageRow,
312            r#"
313UPDATE chatbot_conversation_messages
314SET tool_output_id = $1
315WHERE id = $2
316            "#,
317            tool_output_id,
318            msg.id,
319        )
320        .execute(&mut *tx)
321        .await?;
322    }
323
324    let res = ChatbotConversationMessage::from_row(msg, tool_output, input.tool_call_fields);
325    tx.commit().await?;
326    Ok(res)
327}
328
329pub async fn get_by_conversation_id(
330    conn: &mut PgConnection,
331    conversation_id: Uuid,
332) -> ModelResult<Vec<ChatbotConversationMessage>> {
333    let mut tx = conn.begin().await?;
334    let mut msgs: Vec<ChatbotConversationMessageRow> = sqlx::query_as!(
335        ChatbotConversationMessageRow,
336        r#"
337SELECT
338    id,
339    created_at,
340    updated_at,
341    deleted_at,
342    conversation_id,
343    message,
344    message_role as "message_role: MessageRole",
345    message_is_complete,
346    used_tokens,
347    order_number,
348    tool_output_id
349FROM chatbot_conversation_messages
350WHERE conversation_id = $1
351AND deleted_at IS NULL
352        "#,
353        conversation_id
354    )
355    .fetch_all(&mut *tx)
356    .await?;
357    // Should have the same order as in the conversation.
358    msgs.sort_by_key(|a| a.order_number);
359    let mut res = vec![];
360    for m in msgs {
361        let msg = message_row_to_message(&mut tx, m).await?;
362        res.push(msg);
363    }
364    tx.commit().await?;
365    Ok(res)
366}
367
368pub async fn update(
369    conn: &mut PgConnection,
370    id: Uuid,
371    message: &str,
372    message_is_complete: bool,
373    used_tokens: i32,
374) -> ModelResult<ChatbotConversationMessage> {
375    let mut tx = conn.begin().await?;
376    let row = sqlx::query_as!(
377        ChatbotConversationMessageRow,
378        r#"
379UPDATE chatbot_conversation_messages
380SET message = $2, message_is_complete = $3, used_tokens = $4
381WHERE id = $1
382RETURNING
383    id,
384    created_at,
385    updated_at,
386    deleted_at,
387    conversation_id,
388    message,
389    message_role as "message_role: MessageRole",
390    message_is_complete,
391    used_tokens,
392    order_number,
393    tool_output_id
394        "#,
395        id,
396        Some(message),
397        message_is_complete,
398        used_tokens
399    )
400    .fetch_one(&mut *tx)
401    .await?;
402
403    let res = message_row_to_message(&mut tx, row).await?;
404    tx.commit().await?;
405    Ok(res)
406}
407
408pub async fn delete(conn: &mut PgConnection, id: Uuid) -> ModelResult<ChatbotConversationMessage> {
409    let mut tx = conn.begin().await?;
410
411    let row = sqlx::query_as!(
412        ChatbotConversationMessageRow,
413        r#"
414UPDATE chatbot_conversation_messages
415SET deleted_at = NOW()
416WHERE id = $1
417RETURNING
418    id,
419    created_at,
420    updated_at,
421    deleted_at,
422    conversation_id,
423    message,
424    message_role as "message_role: MessageRole",
425    message_is_complete,
426    used_tokens,
427    order_number,
428    tool_output_id
429        "#,
430        id
431    )
432    .fetch_one(&mut *tx)
433    .await?;
434
435    if let Some(output_id) = row.tool_output_id {
436        chatbot_conversation_message_tool_outputs::delete(&mut tx, output_id).await?;
437    }
438    chatbot_conversation_message_tool_calls::delete_all_by_message_id(&mut tx, row.id).await?;
439
440    let res = message_row_to_message(&mut tx, row).await?;
441    tx.commit().await?;
442    Ok(res)
443}
444
445pub async fn message_row_to_message(
446    conn: &mut PgConnection,
447    row: ChatbotConversationMessageRow,
448) -> ModelResult<ChatbotConversationMessage> {
449    let o = if let Some(id) = row.tool_output_id {
450        Some(chatbot_conversation_message_tool_outputs::get_by_id(conn, id).await?)
451    } else {
452        None
453    };
454    let c = chatbot_conversation_message_tool_calls::get_by_message_id(conn, row.id).await?;
455    let res = ChatbotConversationMessage::from_row(row, o, c);
456    Ok(res)
457}