headless_lms_models/
chatbot_conversation_messages.rs

1use crate::{
2    chatbot_conversation_message_tool_calls::{self, ChatbotConversationMessageToolCall},
3    chatbot_conversation_message_tool_outputs::{self, ChatbotConversationMessageToolOutput},
4    prelude::*,
5};
6
7#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy, Type)]
8#[cfg_attr(feature = "ts_rs", derive(TS))]
9#[sqlx(type_name = "message_role", rename_all = "snake_case")]
10#[serde(rename_all = "snake_case")]
11pub enum MessageRole {
12    Assistant,
13    User,
14    Tool,
15    System,
16}
17
18#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
19pub struct ChatbotConversationMessageRow {
20    pub id: Uuid,
21    pub created_at: DateTime<Utc>,
22    pub updated_at: DateTime<Utc>,
23    pub deleted_at: Option<DateTime<Utc>>,
24    pub conversation_id: Uuid,
25    pub message: Option<String>,
26    pub message_role: MessageRole,
27    pub message_is_complete: bool,
28    pub used_tokens: i32,
29    pub order_number: i32,
30    pub tool_output_id: Option<Uuid>,
31}
32
33#[derive(Clone, PartialEq, Deserialize, Serialize, Debug)]
34#[cfg_attr(feature = "ts_rs", derive(TS))]
35pub struct ChatbotConversationMessage {
36    pub id: Uuid,
37    pub created_at: DateTime<Utc>,
38    pub updated_at: DateTime<Utc>,
39    pub deleted_at: Option<DateTime<Utc>>,
40    pub conversation_id: Uuid,
41    pub message: Option<String>,
42    pub message_role: MessageRole,
43    pub message_is_complete: bool,
44    pub used_tokens: i32,
45    pub order_number: i32,
46    pub tool_output: Option<ChatbotConversationMessageToolOutput>,
47    pub tool_call_fields: Vec<ChatbotConversationMessageToolCall>,
48}
49
50impl Default for ChatbotConversationMessage {
51    fn default() -> Self {
52        Self {
53            id: Uuid::nil(),
54            created_at: Default::default(),
55            updated_at: Default::default(),
56            deleted_at: None,
57            conversation_id: Uuid::nil(),
58            message: Default::default(),
59            message_role: MessageRole::System,
60            message_is_complete: false,
61            used_tokens: Default::default(),
62            order_number: Default::default(),
63            tool_output: None,
64            tool_call_fields: Default::default(),
65        }
66    }
67}
68
69impl ChatbotConversationMessage {
70    pub fn from_row(
71        r: ChatbotConversationMessageRow,
72        o: Option<ChatbotConversationMessageToolOutput>,
73        c: Vec<ChatbotConversationMessageToolCall>,
74    ) -> Self {
75        ChatbotConversationMessage {
76            id: r.id,
77            created_at: r.created_at,
78            updated_at: r.updated_at,
79            deleted_at: r.deleted_at,
80            conversation_id: r.conversation_id,
81            message: r.message,
82            message_role: r.message_role,
83            message_is_complete: r.message_is_complete,
84            used_tokens: r.used_tokens,
85            order_number: r.order_number,
86            tool_output: o,
87            tool_call_fields: c,
88        }
89    }
90}
91
92pub async fn insert(
93    conn: &mut PgConnection,
94    input: ChatbotConversationMessage,
95) -> ModelResult<ChatbotConversationMessage> {
96    let mut tx = conn.begin().await?;
97    let msg = sqlx::query_as!(
98        ChatbotConversationMessageRow,
99        r#"
100INSERT INTO chatbot_conversation_messages (
101    conversation_id,
102    message,
103    message_role,
104    message_is_complete,
105    used_tokens,
106    order_number,
107    tool_output_id
108)
109VALUES ($1, $2, $3, $4, $5, $6, $7)
110RETURNING
111    id,
112    created_at,
113    updated_at,
114    deleted_at,
115    conversation_id,
116    message,
117    message_role as "message_role: MessageRole",
118    message_is_complete,
119    used_tokens,
120    order_number,
121    tool_output_id
122        "#,
123        input.conversation_id,
124        input.message,
125        input.message_role as MessageRole,
126        input.message_is_complete,
127        input.used_tokens,
128        input.order_number,
129        None::<Uuid>,
130    )
131    .fetch_one(&mut *tx)
132    .await?;
133
134    let (tool_output_id, tool_output) = match msg.message_role {
135        MessageRole::Assistant => {
136            if msg.message.is_some() {
137                (None, None)
138            } else if !input.tool_call_fields.is_empty() {
139                chatbot_conversation_message_tool_calls::insert_batch(
140                    &mut tx,
141                    input.tool_call_fields.to_owned(),
142                    msg.id,
143                )
144                .await?;
145                (None, None)
146            } else {
147                return ModelResult::Err(ModelError::new(
148                    ModelErrorType::InvalidRequest,
149                    "A chatbot conversation message with role 'assistant' has to have either a message or tool calls",
150                    None,
151                ));
152            }
153        }
154        MessageRole::Tool => {
155            if let Some(output) = input.tool_output {
156                let o_res =
157                    chatbot_conversation_message_tool_outputs::insert(&mut tx, output, msg.id)
158                        .await?;
159                (Some(o_res.id), Some(o_res))
160            } else {
161                return ModelResult::Err(ModelError::new(
162                    ModelErrorType::InvalidRequest,
163                    "A chatbot conversation message with role 'tool' must have tool output",
164                    None,
165                ));
166            }
167        }
168        MessageRole::User => (None, None),
169        MessageRole::System => {
170            return ModelResult::Err(ModelError::new(
171                ModelErrorType::InvalidRequest,
172                "Can't save system message to database",
173                None,
174            ));
175        }
176    };
177
178    // Update the message to contain the tool_output_id if it was created
179    if tool_output_id.is_some() {
180        sqlx::query_as!(
181            ChatbotConversationMessageRow,
182            r#"
183UPDATE chatbot_conversation_messages
184SET tool_output_id = $1
185WHERE id = $2
186            "#,
187            tool_output_id,
188            msg.id,
189        )
190        .execute(&mut *tx)
191        .await?;
192    }
193
194    let res = ChatbotConversationMessage::from_row(msg, tool_output, input.tool_call_fields);
195    tx.commit().await?;
196    Ok(res)
197}
198
199pub async fn get_by_conversation_id(
200    conn: &mut PgConnection,
201    conversation_id: Uuid,
202) -> ModelResult<Vec<ChatbotConversationMessage>> {
203    let mut tx = conn.begin().await?;
204    let mut msgs: Vec<ChatbotConversationMessageRow> = sqlx::query_as!(
205        ChatbotConversationMessageRow,
206        r#"
207SELECT
208    id,
209    created_at,
210    updated_at,
211    deleted_at,
212    conversation_id,
213    message,
214    message_role as "message_role: MessageRole",
215    message_is_complete,
216    used_tokens,
217    order_number,
218    tool_output_id
219FROM chatbot_conversation_messages
220WHERE conversation_id = $1
221AND deleted_at IS NULL
222        "#,
223        conversation_id
224    )
225    .fetch_all(&mut *tx)
226    .await?;
227    // Should have the same order as in the conversation.
228    msgs.sort_by(|a, b| a.order_number.cmp(&b.order_number));
229    let mut res = vec![];
230    for m in msgs {
231        let msg = message_row_to_message(&mut tx, m).await?;
232        res.push(msg);
233    }
234    tx.commit().await?;
235    Ok(res)
236}
237
238pub async fn update(
239    conn: &mut PgConnection,
240    id: Uuid,
241    message: &str,
242    message_is_complete: bool,
243    used_tokens: i32,
244) -> ModelResult<ChatbotConversationMessage> {
245    let mut tx = conn.begin().await?;
246    let row = sqlx::query_as!(
247        ChatbotConversationMessageRow,
248        r#"
249UPDATE chatbot_conversation_messages
250SET message = $2, message_is_complete = $3, used_tokens = $4
251WHERE id = $1
252RETURNING
253    id,
254    created_at,
255    updated_at,
256    deleted_at,
257    conversation_id,
258    message,
259    message_role as "message_role: MessageRole",
260    message_is_complete,
261    used_tokens,
262    order_number,
263    tool_output_id
264        "#,
265        id,
266        Some(message),
267        message_is_complete,
268        used_tokens
269    )
270    .fetch_one(&mut *tx)
271    .await?;
272
273    let res = message_row_to_message(&mut tx, row).await?;
274    tx.commit().await?;
275    Ok(res)
276}
277
278pub async fn delete(conn: &mut PgConnection, id: Uuid) -> ModelResult<ChatbotConversationMessage> {
279    let mut tx = conn.begin().await?;
280
281    let row = sqlx::query_as!(
282        ChatbotConversationMessageRow,
283        r#"
284UPDATE chatbot_conversation_messages
285SET deleted_at = NOW()
286WHERE id = $1
287RETURNING
288    id,
289    created_at,
290    updated_at,
291    deleted_at,
292    conversation_id,
293    message,
294    message_role as "message_role: MessageRole",
295    message_is_complete,
296    used_tokens,
297    order_number,
298    tool_output_id
299        "#,
300        id
301    )
302    .fetch_one(&mut *tx)
303    .await?;
304
305    if let Some(output_id) = row.tool_output_id {
306        chatbot_conversation_message_tool_outputs::delete(&mut tx, output_id).await?;
307    }
308    chatbot_conversation_message_tool_calls::delete_all_by_message_id(&mut tx, row.id).await?;
309
310    let res = message_row_to_message(&mut tx, row).await?;
311    tx.commit().await?;
312    Ok(res)
313}
314
315pub async fn message_row_to_message(
316    conn: &mut PgConnection,
317    row: ChatbotConversationMessageRow,
318) -> ModelResult<ChatbotConversationMessage> {
319    let o = if let Some(id) = row.tool_output_id {
320        Some(chatbot_conversation_message_tool_outputs::get_by_id(conn, id).await?)
321    } else {
322        None
323    };
324    let c = chatbot_conversation_message_tool_calls::get_by_message_id(conn, row.id).await?;
325    let res = ChatbotConversationMessage::from_row(row, o, c);
326    Ok(res)
327}