Skip to main content

headless_lms_chatbot/
azure_search_index.rs

1use secrecy::ExposeSecret;
2
3use crate::prelude::*;
4use headless_lms_base::config::ApplicationConfiguration;
5use headless_lms_utils::http::REQWEST_CLIENT;
6
/// Azure AI Search REST API version, sent as the `api-version` query parameter on every request
/// made by this module.
const API_VERSION: &str = "2024-07-01";
8
9#[derive(Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct NewIndex {
12    name: String,
13    fields: Vec<Field>,
14    scoring_profiles: Vec<ScoringProfile>,
15    default_scoring_profile: Option<String>,
16    suggesters: Vec<Suggester>,
17    analyzers: Vec<Analyzer>,
18    tokenizers: Vec<serde_json::Value>,
19    token_filters: Vec<serde_json::Value>,
20    char_filters: Vec<serde_json::Value>,
21    cors_options: CorsOptions,
22    encryption_key: Option<EncryptionKey>,
23    similarity: Similarity,
24    semantic: Semantic,
25    vector_search: VectorSearch,
26}
27
28#[derive(Serialize, Deserialize)]
29#[serde(rename_all = "camelCase")]
30pub struct Analyzer {
31    name: String,
32    #[serde(rename = "@odata.type")]
33    odata_type: String,
34    char_filters: Vec<String>,
35    tokenizer: String,
36}
37
38#[derive(Serialize, Deserialize)]
39#[serde(rename_all = "camelCase")]
40pub struct CorsOptions {
41    allowed_origins: Vec<String>,
42    max_age_in_seconds: i64,
43}
44
45#[derive(Serialize, Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct EncryptionKey {
48    key_vault_key_name: String,
49    key_vault_key_version: String,
50    key_vault_uri: String,
51    access_credentials: AccessCredentials,
52}
53
54#[derive(Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct AccessCredentials {
57    application_id: String,
58    application_secret: String,
59}
60
61#[derive(Serialize, Deserialize)]
62#[serde(rename_all = "camelCase")]
63pub struct Field {
64    name: String,
65    #[serde(rename = "type")]
66    field_type: String,
67    key: Option<bool>,
68    searchable: Option<bool>,
69    filterable: Option<bool>,
70    sortable: Option<bool>,
71    facetable: Option<bool>,
72    retrievable: Option<bool>,
73    index_analyzer: Option<String>,
74    search_analyzer: Option<String>,
75    analyzer: Option<String>,
76    synonym_maps: Option<Vec<String>>,
77    dimensions: Option<i64>,
78    vector_search_profile: Option<String>,
79    stored: Option<bool>,
80    vector_encoding: Option<serde_json::Value>,
81}
82
83#[derive(Serialize, Deserialize)]
84pub struct ScoringProfile {
85    name: String,
86    text: Text,
87    functions: Vec<Function>,
88}
89
90#[derive(Serialize, Deserialize)]
91#[serde(rename_all = "camelCase")]
92pub struct Function {
93    #[serde(rename = "type")]
94    function_type: String,
95    boost: i64,
96    field_name: String,
97    interpolation: String,
98    distance: Distance,
99}
100
101#[derive(Serialize, Deserialize)]
102#[serde(rename_all = "camelCase")]
103pub struct Distance {
104    reference_point_parameter: String,
105    boosting_distance: i64,
106}
107
108#[derive(Serialize, Deserialize)]
109pub struct Text {
110    weights: Weights,
111}
112
113#[derive(Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct Weights {
116    hotel_name: i64,
117}
118
119#[derive(Serialize, Deserialize)]
120#[serde(rename_all = "camelCase")]
121pub struct Semantic {
122    default_configuration: String,
123    configurations: Vec<SemanticConfiguration>,
124}
125
126#[derive(Serialize, Deserialize)]
127#[serde(rename_all = "camelCase")]
128pub struct SemanticConfiguration {
129    name: String,
130    prioritized_fields: SemanticConfigurationPrioritizedFields,
131}
132
133#[derive(Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct SemanticConfigurationPrioritizedFields {
136    title_field: FieldDescriptor,
137    prioritized_content_fields: Vec<FieldDescriptor>,
138    prioritized_keywords_fields: Vec<FieldDescriptor>,
139}
140
141#[derive(Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct FieldDescriptor {
144    field_name: String,
145}
146
147#[derive(Serialize, Deserialize)]
148pub struct Similarity {
149    #[serde(rename = "@odata.type")]
150    odata_type: String,
151    b: Option<f64>,
152    k1: Option<f64>,
153}
154
155#[derive(Serialize, Deserialize)]
156#[serde(rename_all = "camelCase")]
157pub struct Suggester {
158    name: String,
159    search_mode: String,
160    source_fields: Vec<String>,
161}
162
163#[derive(Serialize, Deserialize)]
164pub struct VectorSearch {
165    profiles: Vec<Profile>,
166    algorithms: Vec<Algorithm>,
167    compressions: Vec<Compression>,
168    vectorizers: Vec<Vectorizer>,
169}
170
171#[derive(Serialize, Deserialize)]
172#[serde(rename_all = "camelCase")]
173pub struct Vectorizer {
174    name: String,
175    kind: String,
176    #[serde(rename = "azureOpenAIParameters")]
177    azure_open_ai_parameters: AzureOpenAiParameters,
178    custom_web_api_parameters: Option<serde_json::Value>,
179}
180
181#[derive(Serialize, Deserialize)]
182#[serde(rename_all = "camelCase")]
183pub struct AzureOpenAiParameters {
184    resource_uri: String,
185    deployment_id: String,
186    api_key: String,
187    model_name: String,
188    auth_identity: Option<serde_json::Value>,
189}
190
191#[derive(Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct Algorithm {
194    name: String,
195    kind: String,
196    hnsw_parameters: Option<HnswParameters>,
197    exhaustive_knn_parameters: Option<ExhaustiveKnnParameters>,
198}
199
200#[derive(Serialize, Deserialize)]
201pub struct ExhaustiveKnnParameters {
202    metric: String,
203}
204
205#[derive(Serialize, Deserialize)]
206#[serde(rename_all = "camelCase")]
207pub struct HnswParameters {
208    m: i64,
209    metric: String,
210    ef_construction: i64,
211    ef_search: i64,
212}
213
214#[derive(Serialize, Deserialize)]
215#[serde(rename_all = "camelCase")]
216pub struct Compression {
217    name: String,
218    kind: String,
219    scalar_quantization_parameters: Option<ScalarQuantizationParameters>,
220    rerank_with_original_vectors: bool,
221    default_oversampling: i64,
222}
223
224#[derive(Serialize, Deserialize)]
225#[serde(rename_all = "camelCase")]
226pub struct ScalarQuantizationParameters {
227    quantized_data_type: String,
228}
229
230#[derive(Serialize, Deserialize)]
231pub struct Profile {
232    name: String,
233    algorithm: String,
234    compression: Option<String>,
235    vectorizer: Option<String>,
236}
237
238pub async fn does_search_index_exist(
239    index_name: &str,
240    app_config: &ApplicationConfiguration,
241) -> anyhow::Result<bool> {
242    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
243        anyhow::anyhow!("Azure configuration is missing from the application configuration")
244    })?;
245
246    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
247        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
248    })?;
249
250    let mut url = search_config.search_endpoint.clone();
251    url.set_path(&format!("indexes('{}')", index_name));
252    url.set_query(Some(&format!("api-version={}", API_VERSION)));
253
254    let response = REQWEST_CLIENT
255        .get(url)
256        .header("Content-Type", "application/json")
257        .header("api-key", search_config.search_api_key.expose_secret())
258        .send()
259        .await?;
260
261    if response.status().is_success() {
262        Ok(true)
263    } else if response.status() == 404 {
264        Ok(false)
265    } else {
266        let status = response.status();
267        let error_text = response.text().await?;
268        Err(anyhow::anyhow!(
269            "Error checking if index exists. Status: {}. Error: {}",
270            status,
271            error_text
272        ))
273    }
274}
275
276pub async fn create_search_index(
277    index_name: String,
278    app_config: &ApplicationConfiguration,
279) -> anyhow::Result<()> {
280    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
281        anyhow::anyhow!("Azure configuration is missing from the application configuration")
282    })?;
283
284    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
285        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
286    })?;
287
288    let fields = vec![
289        Field {
290            name: "chunk_id".to_string(),
291            field_type: "Edm.String".to_string(),
292            key: Some(true),
293            searchable: Some(true),
294            filterable: Some(true),
295            retrievable: Some(true),
296            stored: Some(true),
297            sortable: Some(false),
298            facetable: Some(false),
299            analyzer: Some("keyword".to_string()),
300            index_analyzer: None,
301            search_analyzer: None,
302            synonym_maps: Some(vec![]),
303            dimensions: None,
304            vector_search_profile: None,
305            vector_encoding: None,
306        },
307        Field {
308            name: "language".to_string(),
309            field_type: "Edm.String".to_string(),
310            key: Some(false),
311            searchable: Some(true),
312            filterable: Some(true),
313            retrievable: Some(true),
314            stored: Some(true),
315            sortable: Some(false),
316            facetable: Some(false),
317            analyzer: None,
318            index_analyzer: None,
319            search_analyzer: None,
320            synonym_maps: Some(vec![]),
321            dimensions: None,
322            vector_search_profile: None,
323            vector_encoding: None,
324        },
325        Field {
326            name: "parent_id".to_string(),
327            field_type: "Edm.String".to_string(),
328            key: Some(false),
329            searchable: Some(true),
330            filterable: Some(true),
331            retrievable: Some(true),
332            stored: Some(true),
333            sortable: Some(true),
334            facetable: Some(true),
335            analyzer: None,
336            index_analyzer: None,
337            search_analyzer: None,
338            synonym_maps: Some(vec![]),
339            dimensions: None,
340            vector_search_profile: None,
341            vector_encoding: None,
342        },
343        Field {
344            name: "chunk".to_string(),
345            field_type: "Edm.String".to_string(),
346            key: Some(false),
347            searchable: Some(true),
348            filterable: Some(false),
349            retrievable: Some(true),
350            stored: Some(true),
351            sortable: Some(false),
352            facetable: Some(false),
353            analyzer: Some("keyword".to_string()),
354            index_analyzer: None,
355            search_analyzer: None,
356            synonym_maps: Some(vec![]),
357            dimensions: None,
358            vector_search_profile: None,
359            vector_encoding: None,
360        },
361        Field {
362            name: "title".to_string(),
363            field_type: "Edm.String".to_string(),
364            key: Some(false),
365            searchable: Some(true),
366            filterable: Some(true),
367            retrievable: Some(true),
368            stored: Some(true),
369            sortable: Some(false),
370            facetable: Some(false),
371            analyzer: None,
372            index_analyzer: None,
373            search_analyzer: None,
374            synonym_maps: Some(vec![]),
375            dimensions: None,
376            vector_search_profile: None,
377            vector_encoding: None,
378        },
379        Field {
380            name: "url".to_string(),
381            field_type: "Edm.String".to_string(),
382            key: Some(false),
383            searchable: Some(false),
384            filterable: Some(true),
385            retrievable: Some(true),
386            stored: Some(true),
387            sortable: Some(false),
388            facetable: Some(false),
389            analyzer: None,
390            index_analyzer: None,
391            search_analyzer: None,
392            synonym_maps: Some(vec![]),
393            dimensions: None,
394            vector_search_profile: None,
395            vector_encoding: None,
396        },
397        Field {
398            name: "course_id".to_string(),
399            field_type: "Edm.String".to_string(),
400            key: Some(false),
401            searchable: Some(false),
402            filterable: Some(true),
403            retrievable: Some(true),
404            stored: Some(true),
405            sortable: Some(false),
406            facetable: Some(false),
407            analyzer: None,
408            index_analyzer: None,
409            search_analyzer: None,
410            synonym_maps: Some(vec![]),
411            dimensions: None,
412            vector_search_profile: None,
413            vector_encoding: None,
414        },
415        Field {
416            name: "text_vector".to_string(),
417            field_type: "Collection(Edm.Single)".to_string(),
418            key: Some(false),
419            searchable: Some(true),
420            filterable: Some(false),
421            retrievable: Some(true),
422            stored: Some(true),
423            sortable: Some(false),
424            facetable: Some(false),
425            analyzer: None,
426            index_analyzer: None,
427            search_analyzer: None,
428            synonym_maps: Some(vec![]),
429            dimensions: Some(1536),
430            vector_search_profile: Some(format!("{}-azureOpenAi-text-profile", index_name)),
431            vector_encoding: None,
432        },
433        Field {
434            name: "filepath".to_string(),
435            field_type: "Edm.String".to_string(),
436            key: Some(false),
437            searchable: Some(true),
438            filterable: Some(true),
439            retrievable: Some(true),
440            stored: Some(true),
441            sortable: Some(false),
442            facetable: Some(false),
443            analyzer: None,
444            index_analyzer: None,
445            search_analyzer: None,
446            synonym_maps: Some(vec![]),
447            dimensions: None,
448            vector_search_profile: None,
449            vector_encoding: None,
450        },
451        Field {
452            name: "chunk_context".to_string(),
453            field_type: "Edm.String".to_string(),
454            key: Some(false),
455            searchable: Some(true),
456            filterable: Some(false),
457            retrievable: Some(true),
458            stored: Some(true),
459            sortable: Some(false),
460            facetable: Some(false),
461            analyzer: None,
462            index_analyzer: None,
463            search_analyzer: None,
464            synonym_maps: Some(vec![]),
465            dimensions: None,
466            vector_search_profile: None,
467            vector_encoding: None,
468        },
469    ];
470
471    let index = NewIndex {
472        name: index_name.clone(),
473        fields,
474        scoring_profiles: vec![],
475        default_scoring_profile: None,
476        suggesters: vec![],
477        analyzers: vec![],
478        tokenizers: vec![],
479        token_filters: vec![],
480        char_filters: vec![],
481        cors_options: CorsOptions {
482            allowed_origins: vec!["*".to_string()],
483            max_age_in_seconds: 300,
484        },
485        encryption_key: None,
486        similarity: Similarity {
487            odata_type: "#Microsoft.Azure.Search.BM25Similarity".to_string(),
488            b: None,
489            k1: None,
490        },
491        semantic: Semantic {
492            default_configuration: format!("{}-semantic-configuration", index_name),
493            configurations: vec![SemanticConfiguration {
494                name: format!("{}-semantic-configuration", index_name),
495                prioritized_fields: SemanticConfigurationPrioritizedFields {
496                    title_field: FieldDescriptor {
497                        field_name: "title".to_string(),
498                    },
499                    prioritized_content_fields: vec![FieldDescriptor {
500                        field_name: "chunk".to_string(),
501                    }],
502                    prioritized_keywords_fields: vec![],
503                },
504            }],
505        },
506        vector_search: VectorSearch {
507            profiles: vec![Profile {
508                name: format!("{}-azureOpenAi-text-profile", index_name),
509                algorithm: format!("{}-algorithm", index_name),
510                vectorizer: Some(format!("{}-azureOpenAi-text-vectorizer", index_name)),
511                compression: None,
512            }],
513            algorithms: vec![Algorithm {
514                name: format!("{}-algorithm", index_name),
515                kind: "hnsw".to_string(),
516                hnsw_parameters: Some(HnswParameters {
517                    m: 4,
518                    metric: "cosine".to_string(),
519                    ef_construction: 400,
520                    ef_search: 500,
521                }),
522                exhaustive_knn_parameters: None,
523            }],
524            vectorizers: vec![Vectorizer {
525                name: format!("{}-azureOpenAi-text-vectorizer", index_name),
526                kind: "azureOpenAI".to_string(),
527                azure_open_ai_parameters: AzureOpenAiParameters {
528                    resource_uri: search_config.vectorizer_resource_uri.clone(),
529                    deployment_id: search_config.vectorizer_deployment_id.clone(),
530                    api_key: search_config.vectorizer_api_key.expose_secret().to_string(),
531                    model_name: search_config.vectorizer_model_name.clone(),
532                    auth_identity: None,
533                },
534                custom_web_api_parameters: None,
535            }],
536            compressions: vec![],
537        },
538    };
539
540    let index_json = serde_json::to_string(&index)?;
541
542    let mut url = search_config.search_endpoint.clone();
543    url.set_path("/indexes");
544    url.set_query(Some(&format!("api-version={}", API_VERSION)));
545
546    let response = REQWEST_CLIENT
547        .post(url)
548        .header("Content-Type", "application/json")
549        .header("api-key", search_config.search_api_key.expose_secret())
550        .body(index_json)
551        .send()
552        .await?;
553
554    // Check for a successful response
555    if response.status().is_success() {
556        println!("Index created successfully: {}", index_name);
557        Ok(())
558    } else {
559        let status = response.status();
560        let error_text = response.text().await?;
561        Err(anyhow::anyhow!(
562            "Failed to create index. Status: {}. Error: {}",
563            status,
564            error_text
565        ))
566    }
567}
568
569#[derive(Serialize, Deserialize)]
570pub struct IndexAction<T> {
571    #[serde(rename = "@search.action")]
572    pub search_action: String,
573    pub document: T,
574}
575
576#[derive(Serialize, Deserialize)]
577pub struct IndexBatch<T> {
578    pub value: Vec<IndexAction<T>>,
579}
580
581pub async fn add_documents_to_index<T>(
582    index_name: &str,
583    documents: Vec<T>,
584    app_config: &ApplicationConfiguration,
585) -> anyhow::Result<()>
586where
587    T: Serialize,
588{
589    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
590        anyhow::anyhow!("Azure configuration is missing from the application configuration")
591    })?;
592
593    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
594        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
595    })?;
596
597    let mut url = search_config.search_endpoint.clone();
598    url.set_path(&format!("indexes('{}')/docs/index", index_name));
599    url.set_query(Some(&format!("api-version={}", API_VERSION)));
600
601    let index_actions: anyhow::Result<Vec<IndexAction<String>>> = documents
602        .into_iter()
603        .map(|doc| {
604            serde_json::to_string(&doc)
605                .map(|document| IndexAction {
606                    search_action: "upload".to_string(),
607                    document,
608                })
609                .map_err(|e| anyhow::anyhow!("Failed to serialize document: {}", e))
610        })
611        .collect();
612    let index_actions = index_actions?;
613
614    let batch = IndexBatch {
615        value: index_actions,
616    };
617
618    let batch_json = serde_json::to_string(&batch)?;
619
620    let response = REQWEST_CLIENT
621        .post(url)
622        .header("Content-Type", "application/json")
623        .header("api-key", search_config.search_api_key.expose_secret())
624        .body(batch_json)
625        .send()
626        .await?;
627
628    if response.status().is_success() {
629        println!("Documents added successfully to index: {}", index_name);
630        Ok(())
631    } else {
632        let status = response.status();
633        let error_text = response.text().await?;
634        Err(anyhow::anyhow!(
635            "Failed to add documents to index. Status: {}. Error: {}",
636            status,
637            error_text
638        ))
639    }
640}