1#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)]
9
10use icu_provider::prelude::*;
11use potential_utf::PotentialUtf8;
12use zerovec::{ZeroMap, ZeroVec};
13
/// Returns the product of `dims` as a `usize`, or `None` if it overflows.
///
/// Matrix validation compares this product against the length of the backing
/// data. The multiplication is checked because three `u16` dimensions can
/// exceed `usize::MAX` on 32-bit targets (e.g. wasm32): an unchecked product
/// would panic in debug builds and wrap in release builds, where the wrapped
/// value could spuriously equal `data.len()`.
#[cfg_attr(not(any(feature = "serde", feature = "datagen")), allow(dead_code))]
fn dims_product(dims: &[u16]) -> Option<usize> {
    dims.iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(usize::from(dim)))
}

// Defines a matrix type `$name` with `$generic` dimensions, backed by a flat
// `ZeroVec<f32>`, together with its validating constructor, an unchecked
// constructor for baked data, and its serde / databake impls.
macro_rules! lstm_matrix {
    ($name:ident, $generic:literal) => {
        /// An `f32` matrix stored as a flat buffer; `dims` gives its shape.
        #[derive(PartialEq, Debug, Clone, zerofrom::ZeroFrom, yoke::Yokeable)]
        #[cfg_attr(feature = "datagen", derive(serde::Serialize))]
        pub struct $name<'data> {
            #[allow(missing_docs)]
            pub(crate) dims: [u16; $generic],
            #[allow(missing_docs)]
            pub(crate) data: ZeroVec<'data, f32>,
        }

        impl<'data> $name<'data> {
            /// Creates a matrix from its shape and flat data.
            ///
            /// # Errors
            ///
            /// Returns an error if the product of `dims` does not equal
            /// `data.len()`, including when that product overflows `usize`.
            #[cfg(any(feature = "serde", feature = "datagen"))]
            pub fn from_parts(
                dims: [u16; $generic],
                data: ZeroVec<'data, f32>,
            ) -> Result<Self, DataError> {
                if dims_product(&dims) != Some(data.len()) {
                    Err(DataError::custom("Dimension mismatch"))
                } else {
                    Ok(Self { dims, data })
                }
            }

            /// Creates a matrix without checking that `dims` matches `data.len()`.
            ///
            /// This is the constructor emitted by the `databake::Bake` impl below.
            #[doc(hidden)]
            pub const fn from_parts_unchecked(
                dims: [u16; $generic],
                data: ZeroVec<'data, f32>,
            ) -> Self {
                Self { dims, data }
            }
        }

        #[cfg(feature = "serde")]
        impl<'de: 'data, 'data> serde::Deserialize<'de> for $name<'data> {
            fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
            where
                S: serde::de::Deserializer<'de>,
            {
                #[derive(serde::Deserialize)]
                struct Raw<'data> {
                    dims: [u16; $generic],
                    #[serde(borrow)]
                    data: ZeroVec<'data, f32>,
                }

                let raw = Raw::deserialize(deserializer)?;

                // Deserialized data goes through the same shape validation
                // as `from_parts`.
                use serde::de::Error;
                Self::from_parts(raw.dims, raw.data)
                    .map_err(|_| S::Error::custom("Dimension mismatch"))
            }
        }

        #[cfg(feature = "datagen")]
        impl databake::Bake for $name<'_> {
            fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
                let dims = self.dims.bake(env);
                let data = self.data.bake(env);
                databake::quote! {
                    icu_segmenter::provider::$name::from_parts_unchecked(#dims, #data)
                }
            }
        }

        #[cfg(feature = "datagen")]
        impl databake::BakeSize for $name<'_> {
            fn borrows_size(&self) -> usize {
                self.data.borrows_size()
            }
        }
    };
}
98
// Concrete matrix types of rank 1, 2 and 3; these are the field types used
// for the weights and biases of `LstmDataFloat32`.
lstm_matrix!(LstmMatrix1, 1);
lstm_matrix!(LstmMatrix2, 2);
lstm_matrix!(LstmMatrix3, 3);
102
/// The kind of model stored in the `model` field of [`LstmDataFloat32`].
///
/// Judging by the variant names, this distinguishes models whose inputs are
/// individual code points from models whose inputs are grapheme clusters
/// (NOTE(review): confirm against the segmenter's use of `model`).
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(
    feature = "datagen",
    derive(serde::Serialize, databake::Bake),
    databake(path = icu_segmenter::provider)
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
pub enum ModelType {
    /// Model over code points.
    Codepoints,
    /// Model over grapheme clusters.
    GraphemeClusters,
}
120
121#[derive(PartialEq, Debug, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
129#[cfg_attr(feature = "datagen", derive(serde::Serialize))]
130#[yoke(prove_covariance_manually)]
131pub struct LstmDataFloat32<'data> {
132 pub(crate) model: ModelType,
134 pub(crate) dic: ZeroMap<'data, PotentialUtf8, u16>,
136 pub(crate) embedding: LstmMatrix2<'data>,
138 pub(crate) fw_w: LstmMatrix3<'data>,
140 pub(crate) fw_u: LstmMatrix3<'data>,
142 pub(crate) fw_b: LstmMatrix2<'data>,
144 pub(crate) bw_w: LstmMatrix3<'data>,
146 pub(crate) bw_u: LstmMatrix3<'data>,
148 pub(crate) bw_b: LstmMatrix2<'data>,
150 pub(crate) time_w: LstmMatrix3<'data>,
152 pub(crate) time_b: LstmMatrix1<'data>,
154}
155
156impl<'data> LstmDataFloat32<'data> {
157 #[doc(hidden)] #[allow(clippy::too_many_arguments)] pub const fn from_parts_unchecked(
160 model: ModelType,
161 dic: ZeroMap<'data, PotentialUtf8, u16>,
162 embedding: LstmMatrix2<'data>,
163 fw_w: LstmMatrix3<'data>,
164 fw_u: LstmMatrix3<'data>,
165 fw_b: LstmMatrix2<'data>,
166 bw_w: LstmMatrix3<'data>,
167 bw_u: LstmMatrix3<'data>,
168 bw_b: LstmMatrix2<'data>,
169 time_w: LstmMatrix3<'data>,
170 time_b: LstmMatrix1<'data>,
171 ) -> Self {
172 Self {
173 model,
174 dic,
175 embedding,
176 fw_w,
177 fw_u,
178 fw_b,
179 bw_w,
180 bw_u,
181 bw_b,
182 time_w,
183 time_b,
184 }
185 }
186
187 #[cfg(any(feature = "serde", feature = "datagen"))]
188 #[allow(clippy::too_many_arguments)] pub fn try_from_parts(
191 model: ModelType,
192 dic: ZeroMap<'data, PotentialUtf8, u16>,
193 embedding: LstmMatrix2<'data>,
194 fw_w: LstmMatrix3<'data>,
195 fw_u: LstmMatrix3<'data>,
196 fw_b: LstmMatrix2<'data>,
197 bw_w: LstmMatrix3<'data>,
198 bw_u: LstmMatrix3<'data>,
199 bw_b: LstmMatrix2<'data>,
200 time_w: LstmMatrix3<'data>,
201 time_b: LstmMatrix1<'data>,
202 ) -> Result<Self, DataError> {
203 let dic_len = u16::try_from(dic.len())
204 .map_err(|_| DataError::custom("Dictionary does not fit in u16"))?;
205
206 let num_classes = embedding.dims[0];
207 let embedd_dim = embedding.dims[1];
208 let hunits = fw_u.dims[2];
209 if num_classes - 1 != dic_len
210 || fw_w.dims != [4, hunits, embedd_dim]
211 || fw_u.dims != [4, hunits, hunits]
212 || fw_b.dims != [4, hunits]
213 || bw_w.dims != [4, hunits, embedd_dim]
214 || bw_u.dims != [4, hunits, hunits]
215 || bw_b.dims != [4, hunits]
216 || time_w.dims != [2, 4, hunits]
217 || time_b.dims != [4]
218 {
219 return Err(DataError::custom("LSTM dimension mismatch"));
220 }
221
222 #[cfg(debug_assertions)]
223 if !dic.iter_copied_values().all(|(_, g)| g < dic_len) {
224 return Err(DataError::custom("Invalid cluster id"));
225 }
226
227 Ok(Self {
228 model,
229 dic,
230 embedding,
231 fw_w,
232 fw_u,
233 fw_b,
234 bw_w,
235 bw_u,
236 bw_b,
237 time_w,
238 time_b,
239 })
240 }
241}
242
#[cfg(feature = "serde")]
impl<'de: 'data, 'data> serde::Deserialize<'de> for LstmDataFloat32<'data> {
    /// Deserializes the raw fields, then validates them with
    /// `LstmDataFloat32::try_from_parts`.
    ///
    /// Any validation failure is reported through `S::Error::custom` using the
    /// underlying `DataError` as the message.
    fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
    where
        S: serde::de::Deserializer<'de>,
    {
        use serde::de::Error;

        // Unvalidated mirror of the struct; field order matches the struct.
        // This impl only exists with the `serde` feature, so `serde(borrow)`
        // can be applied unconditionally.
        #[derive(serde::Deserialize)]
        struct Raw<'data> {
            model: ModelType,
            #[serde(borrow)]
            dic: ZeroMap<'data, PotentialUtf8, u16>,
            #[serde(borrow)]
            embedding: LstmMatrix2<'data>,
            #[serde(borrow)]
            fw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            bw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            time_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            time_b: LstmMatrix1<'data>,
        }

        let Raw {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = Raw::deserialize(deserializer)?;

        // Forward the real cause: `try_from_parts` also fails for an oversized
        // dictionary or an out-of-range cluster id, which a fixed
        // "Invalid dimensions" message would mislabel.
        Self::try_from_parts(
            model, dic, embedding, fw_w, fw_u, fw_b, bw_w, bw_u, bw_b, time_w, time_b,
        )
        .map_err(S::Error::custom)
    }
}
293
#[cfg(feature = "datagen")]
impl databake::Bake for LstmDataFloat32<'_> {
    /// Bakes into a call to `LstmDataFloat32::from_parts_unchecked`, passing
    /// every field's baked form in declaration order.
    fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
        // Exhaustive destructuring: a new struct field that is not handled
        // here becomes a compile error.
        let Self {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = self;

        let model = model.bake(env);
        let dic = dic.bake(env);
        let embedding = embedding.bake(env);
        let fw_w = fw_w.bake(env);
        let fw_u = fw_u.bake(env);
        let fw_b = fw_b.bake(env);
        let bw_w = bw_w.bake(env);
        let bw_u = bw_u.bake(env);
        let bw_b = bw_b.bake(env);
        let time_w = time_w.bake(env);
        let time_b = time_b.bake(env);

        databake::quote! {
            icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked(
                #model,
                #dic,
                #embedding,
                #fw_w,
                #fw_u,
                #fw_b,
                #bw_w,
                #bw_u,
                #bw_b,
                #time_w,
                #time_b,
            )
        }
    }
}
325
#[cfg(feature = "datagen")]
impl databake::BakeSize for LstmDataFloat32<'_> {
    /// Total borrowed size: the sum of every field's `borrows_size`.
    fn borrows_size(&self) -> usize {
        // Exhaustive destructuring so that no field can be silently left out.
        let Self {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = self;

        model.borrows_size()
            + dic.borrows_size()
            + embedding.borrows_size()
            + fw_w.borrows_size()
            + fw_u.borrows_size()
            + fw_b.borrows_size()
            + bw_w.borrows_size()
            + bw_u.borrows_size()
            + bw_b.borrows_size()
            + time_w.borrows_size()
            + time_b.borrows_size()
    }
}
342
343#[derive(Debug, PartialEq, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
359#[cfg_attr(feature = "datagen", derive(serde::Serialize, databake::Bake))]
360#[cfg_attr(feature = "datagen", databake(path = icu_segmenter::provider))]
361#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
362#[yoke(prove_covariance_manually)]
363#[non_exhaustive]
364pub enum LstmData<'data> {
365 Float32(#[cfg_attr(feature = "serde", serde(borrow))] LstmDataFloat32<'data>),
367 }
371
// Registers `LstmData` with `icu_provider::data_struct!`.
// NOTE(review): the trailing `#[cfg(feature = "datagen")]` is an argument to
// the macro; presumably it gates datagen-only items the macro generates —
// confirm against the `data_struct!` documentation.
icu_provider::data_struct!(
    LstmData<'_>,
    #[cfg(feature = "datagen")]
);