headless_lms_models/
chatbot_conversation_messages.rs

1use std::fmt;
2
3use crate::{
4    chatbot_conversation_message_tool_calls::{self, ChatbotConversationMessageToolCall},
5    chatbot_conversation_message_tool_outputs::{self, ChatbotConversationMessageToolOutput},
6    prelude::*,
7};
8
/// Role of the author of a chatbot conversation message.
///
/// Persisted as the Postgres `message_role` type and serialized with serde; both
/// use snake_case variant names (e.g. `assistant`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy, Type)]
#[cfg_attr(feature = "ts_rs", derive(TS))]
#[sqlx(type_name = "message_role", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    /// Assistant message; to be stored it must carry a message and/or tool calls.
    Assistant,
    /// User message.
    User,
    /// Tool message; to be stored it must carry a tool output.
    Tool,
    /// System message; `insert` refuses to store these.
    System,
}
19
20impl fmt::Display for MessageRole {
21    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22        write!(f, "{:?}", self)
23    }
24}
25
/// One row of the `chatbot_conversation_messages` table, as returned by the
/// queries in this module.
///
/// Unlike [`ChatbotConversationMessage`], the tool output is only referenced by
/// id and the tool calls are not included. Use [`message_row_to_message`] to load
/// them.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ChatbotConversationMessageRow {
    pub id: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Set by `delete`; rows with a value here are skipped by `get_by_conversation_id`.
    pub deleted_at: Option<DateTime<Utc>>,
    pub conversation_id: Uuid,
    /// Text content; may be `None` (e.g. an assistant message carrying only tool calls).
    pub message: Option<String>,
    pub message_role: MessageRole,
    pub message_is_complete: bool,
    pub used_tokens: i32,
    /// Position of the message in its conversation; listings are ordered by it.
    pub order_number: i32,
    /// Id of the tool output attached to a `Tool` message, if any.
    pub tool_output_id: Option<Uuid>,
}
40
/// A chatbot conversation message together with its tool output and tool calls.
///
/// This is the type returned by the functions in this module. It is assembled
/// from a [`ChatbotConversationMessageRow`] by [`ChatbotConversationMessage::from_row`].
#[derive(Clone, PartialEq, Deserialize, Serialize, Debug)]
#[cfg_attr(feature = "ts_rs", derive(TS))]
pub struct ChatbotConversationMessage {
    pub id: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub deleted_at: Option<DateTime<Utc>>,
    pub conversation_id: Uuid,
    /// Text content; may be `None` (e.g. an assistant message carrying only tool calls).
    pub message: Option<String>,
    pub message_role: MessageRole,
    pub message_is_complete: bool,
    pub used_tokens: i32,
    /// Position of the message in its conversation.
    pub order_number: i32,
    /// Tool output attached to a `Tool` message, if any.
    pub tool_output: Option<ChatbotConversationMessageToolOutput>,
    /// Tool calls attached to the message (used by `Assistant` messages).
    pub tool_call_fields: Vec<ChatbotConversationMessageToolCall>,
}
57
58impl Default for ChatbotConversationMessage {
59    fn default() -> Self {
60        Self {
61            id: Uuid::nil(),
62            created_at: Default::default(),
63            updated_at: Default::default(),
64            deleted_at: None,
65            conversation_id: Uuid::nil(),
66            message: Default::default(),
67            message_role: MessageRole::System,
68            message_is_complete: false,
69            used_tokens: Default::default(),
70            order_number: Default::default(),
71            tool_output: None,
72            tool_call_fields: Default::default(),
73        }
74    }
75}
76
77impl ChatbotConversationMessage {
78    pub fn from_row(
79        r: ChatbotConversationMessageRow,
80        o: Option<ChatbotConversationMessageToolOutput>,
81        c: Vec<ChatbotConversationMessageToolCall>,
82    ) -> Self {
83        ChatbotConversationMessage {
84            id: r.id,
85            created_at: r.created_at,
86            updated_at: r.updated_at,
87            deleted_at: r.deleted_at,
88            conversation_id: r.conversation_id,
89            message: r.message,
90            message_role: r.message_role,
91            message_is_complete: r.message_is_complete,
92            used_tokens: r.used_tokens,
93            order_number: r.order_number,
94            tool_output: o,
95            tool_call_fields: c,
96        }
97    }
98}
99
100pub async fn insert(
101    conn: &mut PgConnection,
102    input: ChatbotConversationMessage,
103) -> ModelResult<ChatbotConversationMessage> {
104    let mut tx = conn.begin().await?;
105    let msg = sqlx::query_as!(
106        ChatbotConversationMessageRow,
107        r#"
108INSERT INTO chatbot_conversation_messages (
109    conversation_id,
110    message,
111    message_role,
112    message_is_complete,
113    used_tokens,
114    order_number,
115    tool_output_id
116)
117VALUES ($1, $2, $3, $4, $5, $6, $7)
118RETURNING
119    id,
120    created_at,
121    updated_at,
122    deleted_at,
123    conversation_id,
124    message,
125    message_role as "message_role: MessageRole",
126    message_is_complete,
127    used_tokens,
128    order_number,
129    tool_output_id
130        "#,
131        input.conversation_id,
132        input.message,
133        input.message_role as MessageRole,
134        input.message_is_complete,
135        input.used_tokens,
136        input.order_number,
137        None::<Uuid>,
138    )
139    .fetch_one(&mut *tx)
140    .await?;
141
142    let (tool_output_id, tool_output) = match msg.message_role {
143        MessageRole::Assistant => {
144            if msg.message.is_some() {
145                (None, None)
146            } else if !input.tool_call_fields.is_empty() {
147                chatbot_conversation_message_tool_calls::insert_batch(
148                    &mut tx,
149                    input.tool_call_fields.to_owned(),
150                    msg.id,
151                )
152                .await?;
153                (None, None)
154            } else {
155                return ModelResult::Err(ModelError::new(
156                    ModelErrorType::InvalidRequest,
157                    "A chatbot conversation message with role 'assistant' has to have either a message or tool calls",
158                    None,
159                ));
160            }
161        }
162        MessageRole::Tool => {
163            if let Some(output) = input.tool_output {
164                let o_res =
165                    chatbot_conversation_message_tool_outputs::insert(&mut tx, output, msg.id)
166                        .await?;
167                (Some(o_res.id), Some(o_res))
168            } else {
169                return ModelResult::Err(ModelError::new(
170                    ModelErrorType::InvalidRequest,
171                    "A chatbot conversation message with role 'tool' must have tool output",
172                    None,
173                ));
174            }
175        }
176        MessageRole::User => (None, None),
177        MessageRole::System => {
178            return ModelResult::Err(ModelError::new(
179                ModelErrorType::InvalidRequest,
180                "Can't save system message to database",
181                None,
182            ));
183        }
184    };
185
186    // Update the message to contain the tool_output_id if it was created
187    if tool_output_id.is_some() {
188        sqlx::query_as!(
189            ChatbotConversationMessageRow,
190            r#"
191UPDATE chatbot_conversation_messages
192SET tool_output_id = $1
193WHERE id = $2
194            "#,
195            tool_output_id,
196            msg.id,
197        )
198        .execute(&mut *tx)
199        .await?;
200    }
201
202    let res = ChatbotConversationMessage::from_row(msg, tool_output, input.tool_call_fields);
203    tx.commit().await?;
204    Ok(res)
205}
206
207pub async fn get_by_conversation_id(
208    conn: &mut PgConnection,
209    conversation_id: Uuid,
210) -> ModelResult<Vec<ChatbotConversationMessage>> {
211    let mut tx = conn.begin().await?;
212    let mut msgs: Vec<ChatbotConversationMessageRow> = sqlx::query_as!(
213        ChatbotConversationMessageRow,
214        r#"
215SELECT
216    id,
217    created_at,
218    updated_at,
219    deleted_at,
220    conversation_id,
221    message,
222    message_role as "message_role: MessageRole",
223    message_is_complete,
224    used_tokens,
225    order_number,
226    tool_output_id
227FROM chatbot_conversation_messages
228WHERE conversation_id = $1
229AND deleted_at IS NULL
230        "#,
231        conversation_id
232    )
233    .fetch_all(&mut *tx)
234    .await?;
235    // Should have the same order as in the conversation.
236    msgs.sort_by(|a, b| a.order_number.cmp(&b.order_number));
237    let mut res = vec![];
238    for m in msgs {
239        let msg = message_row_to_message(&mut tx, m).await?;
240        res.push(msg);
241    }
242    tx.commit().await?;
243    Ok(res)
244}
245
246pub async fn update(
247    conn: &mut PgConnection,
248    id: Uuid,
249    message: &str,
250    message_is_complete: bool,
251    used_tokens: i32,
252) -> ModelResult<ChatbotConversationMessage> {
253    let mut tx = conn.begin().await?;
254    let row = sqlx::query_as!(
255        ChatbotConversationMessageRow,
256        r#"
257UPDATE chatbot_conversation_messages
258SET message = $2, message_is_complete = $3, used_tokens = $4
259WHERE id = $1
260RETURNING
261    id,
262    created_at,
263    updated_at,
264    deleted_at,
265    conversation_id,
266    message,
267    message_role as "message_role: MessageRole",
268    message_is_complete,
269    used_tokens,
270    order_number,
271    tool_output_id
272        "#,
273        id,
274        Some(message),
275        message_is_complete,
276        used_tokens
277    )
278    .fetch_one(&mut *tx)
279    .await?;
280
281    let res = message_row_to_message(&mut tx, row).await?;
282    tx.commit().await?;
283    Ok(res)
284}
285
286pub async fn delete(conn: &mut PgConnection, id: Uuid) -> ModelResult<ChatbotConversationMessage> {
287    let mut tx = conn.begin().await?;
288
289    let row = sqlx::query_as!(
290        ChatbotConversationMessageRow,
291        r#"
292UPDATE chatbot_conversation_messages
293SET deleted_at = NOW()
294WHERE id = $1
295RETURNING
296    id,
297    created_at,
298    updated_at,
299    deleted_at,
300    conversation_id,
301    message,
302    message_role as "message_role: MessageRole",
303    message_is_complete,
304    used_tokens,
305    order_number,
306    tool_output_id
307        "#,
308        id
309    )
310    .fetch_one(&mut *tx)
311    .await?;
312
313    if let Some(output_id) = row.tool_output_id {
314        chatbot_conversation_message_tool_outputs::delete(&mut tx, output_id).await?;
315    }
316    chatbot_conversation_message_tool_calls::delete_all_by_message_id(&mut tx, row.id).await?;
317
318    let res = message_row_to_message(&mut tx, row).await?;
319    tx.commit().await?;
320    Ok(res)
321}
322
323pub async fn message_row_to_message(
324    conn: &mut PgConnection,
325    row: ChatbotConversationMessageRow,
326) -> ModelResult<ChatbotConversationMessage> {
327    let o = if let Some(id) = row.tool_output_id {
328        Some(chatbot_conversation_message_tool_outputs::get_by_id(conn, id).await?)
329    } else {
330        None
331    };
332    let c = chatbot_conversation_message_tool_calls::get_by_message_id(conn, row.id).await?;
333    let res = ChatbotConversationMessage::from_row(row, o, c);
334    Ok(res)
335}