1use std::fmt;
2
3use crate::{
4 chatbot_conversation_message_tool_calls::{self, ChatbotConversationMessageToolCall},
5 chatbot_conversation_message_tool_outputs::{self, ChatbotConversationMessageToolOutput},
6 prelude::*,
7};
8
/// Author role of a chatbot conversation message.
///
/// Stored in the database enum type `message_role` and serialized with serde; both
/// mappings use snake_case variant names (`assistant`, `user`, `tool`, `system`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy, Type)]
#[cfg_attr(feature = "ts_rs", derive(TS))]
#[sqlx(type_name = "message_role", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    /// Assistant message; `insert` requires it to have either text or tool calls.
    Assistant,
    /// Message written by the user.
    User,
    /// Tool message; `insert` requires it to carry a tool output.
    Tool,
    /// System message; `insert` refuses to store it in the database.
    System,
}
19
20impl fmt::Display for MessageRole {
21 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22 write!(f, "{:?}", self)
23 }
24}
25
/// One row of the `chatbot_conversation_messages` table, exactly as selected by the
/// queries in this module.
///
/// Related data is not resolved here: the tool output is only referenced through
/// `tool_output_id`, and tool calls are not included at all. Use
/// `ChatbotConversationMessage::from_row` / `message_row_to_message` to get the full message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ChatbotConversationMessageRow {
    pub id: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Set by `delete` (soft delete); rows with a value are skipped by `get_by_conversation_id`.
    pub deleted_at: Option<DateTime<Utc>>,
    pub conversation_id: Uuid,
    /// Message text; may be `None` (e.g. an assistant message that only has tool calls).
    pub message: Option<String>,
    pub message_role: MessageRole,
    pub message_is_complete: bool,
    pub used_tokens: i32,
    /// Position of the message within its conversation; used for ordering.
    pub order_number: i32,
    /// Id of the linked tool output row; `insert` sets it only for `Tool` messages.
    pub tool_output_id: Option<Uuid>,
}
40
/// A chatbot conversation message with its related tool data resolved.
///
/// Mirrors `ChatbotConversationMessageRow`, but replaces `tool_output_id` with the
/// loaded tool output and adds the message's tool calls.
#[derive(Clone, PartialEq, Deserialize, Serialize, Debug)]
#[cfg_attr(feature = "ts_rs", derive(TS))]
pub struct ChatbotConversationMessage {
    pub id: Uuid,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Set when the message has been soft-deleted.
    pub deleted_at: Option<DateTime<Utc>>,
    pub conversation_id: Uuid,
    /// Message text; may be `None` (e.g. an assistant message that only has tool calls).
    pub message: Option<String>,
    pub message_role: MessageRole,
    pub message_is_complete: bool,
    pub used_tokens: i32,
    /// Position of the message within its conversation.
    pub order_number: i32,
    /// Tool output attached to the message (present for `Tool` messages).
    pub tool_output: Option<ChatbotConversationMessageToolOutput>,
    /// Tool calls attached to the message (used by `Assistant` messages without text).
    pub tool_call_fields: Vec<ChatbotConversationMessageToolCall>,
}
57
58impl Default for ChatbotConversationMessage {
59 fn default() -> Self {
60 Self {
61 id: Uuid::nil(),
62 created_at: Default::default(),
63 updated_at: Default::default(),
64 deleted_at: None,
65 conversation_id: Uuid::nil(),
66 message: Default::default(),
67 message_role: MessageRole::System,
68 message_is_complete: false,
69 used_tokens: Default::default(),
70 order_number: Default::default(),
71 tool_output: None,
72 tool_call_fields: Default::default(),
73 }
74 }
75}
76
77impl ChatbotConversationMessage {
78 pub fn from_row(
79 r: ChatbotConversationMessageRow,
80 o: Option<ChatbotConversationMessageToolOutput>,
81 c: Vec<ChatbotConversationMessageToolCall>,
82 ) -> Self {
83 ChatbotConversationMessage {
84 id: r.id,
85 created_at: r.created_at,
86 updated_at: r.updated_at,
87 deleted_at: r.deleted_at,
88 conversation_id: r.conversation_id,
89 message: r.message,
90 message_role: r.message_role,
91 message_is_complete: r.message_is_complete,
92 used_tokens: r.used_tokens,
93 order_number: r.order_number,
94 tool_output: o,
95 tool_call_fields: c,
96 }
97 }
98}
99
100pub async fn insert(
101 conn: &mut PgConnection,
102 input: ChatbotConversationMessage,
103) -> ModelResult<ChatbotConversationMessage> {
104 let mut tx = conn.begin().await?;
105 let msg = sqlx::query_as!(
106 ChatbotConversationMessageRow,
107 r#"
108INSERT INTO chatbot_conversation_messages (
109 conversation_id,
110 message,
111 message_role,
112 message_is_complete,
113 used_tokens,
114 order_number,
115 tool_output_id
116)
117VALUES ($1, $2, $3, $4, $5, $6, $7)
118RETURNING
119 id,
120 created_at,
121 updated_at,
122 deleted_at,
123 conversation_id,
124 message,
125 message_role as "message_role: MessageRole",
126 message_is_complete,
127 used_tokens,
128 order_number,
129 tool_output_id
130 "#,
131 input.conversation_id,
132 input.message,
133 input.message_role as MessageRole,
134 input.message_is_complete,
135 input.used_tokens,
136 input.order_number,
137 None::<Uuid>,
138 )
139 .fetch_one(&mut *tx)
140 .await?;
141
142 let (tool_output_id, tool_output) = match msg.message_role {
143 MessageRole::Assistant => {
144 if msg.message.is_some() {
145 (None, None)
146 } else if !input.tool_call_fields.is_empty() {
147 chatbot_conversation_message_tool_calls::insert_batch(
148 &mut tx,
149 input.tool_call_fields.to_owned(),
150 msg.id,
151 )
152 .await?;
153 (None, None)
154 } else {
155 return ModelResult::Err(ModelError::new(
156 ModelErrorType::InvalidRequest,
157 "A chatbot conversation message with role 'assistant' has to have either a message or tool calls",
158 None,
159 ));
160 }
161 }
162 MessageRole::Tool => {
163 if let Some(output) = input.tool_output {
164 let o_res =
165 chatbot_conversation_message_tool_outputs::insert(&mut tx, output, msg.id)
166 .await?;
167 (Some(o_res.id), Some(o_res))
168 } else {
169 return ModelResult::Err(ModelError::new(
170 ModelErrorType::InvalidRequest,
171 "A chatbot conversation message with role 'tool' must have tool output",
172 None,
173 ));
174 }
175 }
176 MessageRole::User => (None, None),
177 MessageRole::System => {
178 return ModelResult::Err(ModelError::new(
179 ModelErrorType::InvalidRequest,
180 "Can't save system message to database",
181 None,
182 ));
183 }
184 };
185
186 if tool_output_id.is_some() {
188 sqlx::query_as!(
189 ChatbotConversationMessageRow,
190 r#"
191UPDATE chatbot_conversation_messages
192SET tool_output_id = $1
193WHERE id = $2
194 "#,
195 tool_output_id,
196 msg.id,
197 )
198 .execute(&mut *tx)
199 .await?;
200 }
201
202 let res = ChatbotConversationMessage::from_row(msg, tool_output, input.tool_call_fields);
203 tx.commit().await?;
204 Ok(res)
205}
206
207pub async fn get_by_conversation_id(
208 conn: &mut PgConnection,
209 conversation_id: Uuid,
210) -> ModelResult<Vec<ChatbotConversationMessage>> {
211 let mut tx = conn.begin().await?;
212 let mut msgs: Vec<ChatbotConversationMessageRow> = sqlx::query_as!(
213 ChatbotConversationMessageRow,
214 r#"
215SELECT
216 id,
217 created_at,
218 updated_at,
219 deleted_at,
220 conversation_id,
221 message,
222 message_role as "message_role: MessageRole",
223 message_is_complete,
224 used_tokens,
225 order_number,
226 tool_output_id
227FROM chatbot_conversation_messages
228WHERE conversation_id = $1
229AND deleted_at IS NULL
230 "#,
231 conversation_id
232 )
233 .fetch_all(&mut *tx)
234 .await?;
235 msgs.sort_by(|a, b| a.order_number.cmp(&b.order_number));
237 let mut res = vec![];
238 for m in msgs {
239 let msg = message_row_to_message(&mut tx, m).await?;
240 res.push(msg);
241 }
242 tx.commit().await?;
243 Ok(res)
244}
245
246pub async fn update(
247 conn: &mut PgConnection,
248 id: Uuid,
249 message: &str,
250 message_is_complete: bool,
251 used_tokens: i32,
252) -> ModelResult<ChatbotConversationMessage> {
253 let mut tx = conn.begin().await?;
254 let row = sqlx::query_as!(
255 ChatbotConversationMessageRow,
256 r#"
257UPDATE chatbot_conversation_messages
258SET message = $2, message_is_complete = $3, used_tokens = $4
259WHERE id = $1
260RETURNING
261 id,
262 created_at,
263 updated_at,
264 deleted_at,
265 conversation_id,
266 message,
267 message_role as "message_role: MessageRole",
268 message_is_complete,
269 used_tokens,
270 order_number,
271 tool_output_id
272 "#,
273 id,
274 Some(message),
275 message_is_complete,
276 used_tokens
277 )
278 .fetch_one(&mut *tx)
279 .await?;
280
281 let res = message_row_to_message(&mut tx, row).await?;
282 tx.commit().await?;
283 Ok(res)
284}
285
286pub async fn delete(conn: &mut PgConnection, id: Uuid) -> ModelResult<ChatbotConversationMessage> {
287 let mut tx = conn.begin().await?;
288
289 let row = sqlx::query_as!(
290 ChatbotConversationMessageRow,
291 r#"
292UPDATE chatbot_conversation_messages
293SET deleted_at = NOW()
294WHERE id = $1
295RETURNING
296 id,
297 created_at,
298 updated_at,
299 deleted_at,
300 conversation_id,
301 message,
302 message_role as "message_role: MessageRole",
303 message_is_complete,
304 used_tokens,
305 order_number,
306 tool_output_id
307 "#,
308 id
309 )
310 .fetch_one(&mut *tx)
311 .await?;
312
313 if let Some(output_id) = row.tool_output_id {
314 chatbot_conversation_message_tool_outputs::delete(&mut tx, output_id).await?;
315 }
316 chatbot_conversation_message_tool_calls::delete_all_by_message_id(&mut tx, row.id).await?;
317
318 let res = message_row_to_message(&mut tx, row).await?;
319 tx.commit().await?;
320 Ok(res)
321}
322
323pub async fn message_row_to_message(
324 conn: &mut PgConnection,
325 row: ChatbotConversationMessageRow,
326) -> ModelResult<ChatbotConversationMessage> {
327 let o = if let Some(id) = row.tool_output_id {
328 Some(chatbot_conversation_message_tool_outputs::get_by_id(conn, id).await?)
329 } else {
330 None
331 };
332 let c = chatbot_conversation_message_tool_calls::get_by_message_id(conn, row.id).await?;
333 let res = ChatbotConversationMessage::from_row(row, o, c);
334 Ok(res)
335}