1use std::ffi::c_void;
2
3pub trait ProgressCallback: Sized {
4 type SplitProgressType: SplitProgressBuilder;
5
6 fn progress(&mut self, progress: usize, total: usize) -> bool;
10
11 unsafe extern "C" fn cb_progress_callback(
12 ctxt: *mut c_void,
13 progress: usize,
14 total: usize,
15 ) -> bool {
16 let ctxt: &mut Self = &mut *(ctxt as *mut Self);
17 ctxt.progress(progress, total)
18 }
19
20 #[allow(clippy::wrong_self_convention)]
21 unsafe fn into_raw(&mut self) -> *mut c_void {
22 self as *mut Self as *mut c_void
23 }
24
25 fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType;
38}
39
40pub trait SplitProgressBuilder {
41 type Progress<'a>: ProgressCallback
42 where
43 Self: 'a;
44 fn next_subpart(&mut self) -> Option<Self::Progress<'_>>;
45}
46
47impl<F> ProgressCallback for F
48where
49 F: FnMut(usize, usize) -> bool,
50{
51 type SplitProgressType = SplitProgress<F>;
52
53 fn progress(&mut self, progress: usize, total: usize) -> bool {
54 self(progress, total)
55 }
56
57 fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType {
58 SplitProgress::new(self, subpart_weights)
59 }
60}
61
/// A placeholder progress callback whose C entry point discards every report.
///
/// Splitting it yields a [`SplitProgressNop`].
#[derive(Debug, Clone, Copy, Default)]
pub struct NoProgressCallback;
63
64impl ProgressCallback for NoProgressCallback {
65 type SplitProgressType = SplitProgressNop;
66
67 fn progress(&mut self, _progress: usize, _total: usize) -> bool {
68 unreachable!()
69 }
70
71 unsafe extern "C" fn cb_progress_callback(
72 _ctxt: *mut c_void,
73 _progress: usize,
74 _total: usize,
75 ) -> bool {
76 true
77 }
78
79 fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType {
80 SplitProgressNop(subpart_weights.len())
81 }
82}
83
/// Split builder for [`NoProgressCallback`]; the field counts the subparts that
/// have not been handed out yet.
#[derive(Debug)]
pub struct SplitProgressNop(usize);
85
86impl SplitProgressBuilder for SplitProgressNop {
87 type Progress<'a> = NoProgressCallback;
88
89 fn next_subpart(&mut self) -> Option<Self::Progress<'_>> {
90 if self.0 == 0 {
91 return None;
92 }
93 self.0 -= 1;
94 Some(NoProgressCallback)
95 }
96}
97
/// Divides the progress reported to `callback` among weighted, consecutive subparts.
pub struct SplitProgress<P> {
    /// The callback that receives the aggregated progress.
    callback: P,
    /// Weights of the subparts not yet finished; the head is the one in progress.
    subpart_weights: &'static [usize],
    /// Summed weight of the subparts that have already finished.
    progress: usize,
    /// Sum of every weight originally passed to `SplitProgress::new`.
    total: usize,
}
104
105impl<P: ProgressCallback> SplitProgress<P> {
106 pub fn new(callback: P, subpart_weights: &'static [usize]) -> Self {
107 let total = subpart_weights.iter().sum();
108 Self {
109 callback,
110 subpart_weights,
111 total,
112 progress: 0,
113 }
114 }
115
116 pub fn next_subpart(&mut self) -> Option<SplitProgressInstance<'_, P>> {
117 if self.subpart_weights.is_empty() {
118 return None;
119 }
120 Some(SplitProgressInstance { progress: self })
121 }
122}
123
124impl<P: ProgressCallback> SplitProgressBuilder for SplitProgress<P> {
125 type Progress<'a>
126 = SplitProgressInstance<'a, P>
127 where
128 Self: 'a;
129 fn next_subpart(&mut self) -> Option<Self::Progress<'_>> {
130 self.next_subpart()
131 }
132}
133
134pub struct SplitProgressInstance<'a, P: ProgressCallback> {
135 progress: &'a mut SplitProgress<P>,
136}
137
138impl<P: ProgressCallback> Drop for SplitProgressInstance<'_, P> {
139 fn drop(&mut self) {
140 self.progress.progress += self.progress.subpart_weights[0];
141 self.progress.subpart_weights = &self.progress.subpart_weights[1..];
142 }
143}
144
145impl<P: ProgressCallback> ProgressCallback for SplitProgressInstance<'_, P> {
146 type SplitProgressType = SplitProgress<Self>;
147
148 fn progress(&mut self, progress: usize, total: usize) -> bool {
149 let subpart_progress = (self.progress.subpart_weights[0] * progress) / total;
150 let progress = self.progress.progress + subpart_progress;
151 self.progress
152 .callback
153 .progress(progress, self.progress.total)
154 }
155
156 fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType {
157 SplitProgress::new(self, subpart_weights)
158 }
159}
160
#[cfg(test)]
mod test {
    use std::cell::Cell;

    use super::*;

    /// Returns a callback that records the most recent aggregated progress in `last`.
    fn recorder(last: &Cell<usize>) -> impl FnMut(usize, usize) -> bool + '_ {
        move |reported, _total| {
            last.set(reported);
            true
        }
    }

    #[test]
    fn progress_simple() {
        let last = Cell::new(0);
        let mut callback = recorder(&last);
        for step in [0, 1, 50, 99, 100] {
            callback.progress(step, 100);
            assert_eq!(last.get(), step);
        }
    }

    #[test]
    fn progress_simple_split() {
        let last = Cell::new(0);
        let mut split = recorder(&last).split(&[25, 50, 25]);

        // First quarter.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 0);
            part.progress(100, 100);
            assert_eq!(last.get(), 25);
        }

        // Middle half: 25% of a weight-50 part is 12.5, so allow either rounding.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 25);
            part.progress(25, 100);
            assert!((36..=37).contains(&last.get()));
            part.progress(50, 100);
            assert_eq!(last.get(), 50);
            part.progress(100, 100);
            assert_eq!(last.get(), 75);
        }

        // Final quarter.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 75);
            part.progress(100, 100);
            assert_eq!(last.get(), 100);
        }

        assert!(split.next_subpart().is_none());
    }

    #[test]
    fn progress_recursive_split() {
        let last = Cell::new(0);
        let mut split = recorder(&last).split(&[25, 50, 25]);

        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 0);
            part.progress(100, 100);
            assert_eq!(last.get(), 25);
        }

        // The middle part is itself split into two equal halves.
        {
            let middle = split.next_subpart().unwrap();
            let mut nested = middle.split(&[50, 50]);

            {
                let mut half = nested.next_subpart().unwrap();
                half.progress(0, 100);
                assert_eq!(last.get(), 25);
                half.progress(100, 100);
                assert_eq!(last.get(), 50);
            }

            {
                let mut half = nested.next_subpart().unwrap();
                half.progress(0, 100);
                assert_eq!(last.get(), 50);
                half.progress(100, 100);
                assert_eq!(last.get(), 75);
            }
        }

        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 75);
            part.progress(100, 100);
            assert_eq!(last.get(), 100);
        }

        assert!(split.next_subpart().is_none());
    }
}
275}