Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions dfir_lang/src/graph/ops/fold.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use quote::quote_spanned;

use super::{
DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
RANGE_1, WriteContextArgs,
};

Expand Down Expand Up @@ -46,7 +46,7 @@ pub const FOLD: OperatorConstraints = OperatorConstraints {
flo_type: None,
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| Some(DelayType::Stratum),
input_delaytype_fn: |_| None,
write_fn: |wc @ &WriteContextArgs {
root,
op_span,
Expand Down Expand Up @@ -126,15 +126,44 @@ pub const FOLD: OperatorConstraints = OperatorConstraints {
)
);
}
} else {
assert_eq!(0, outputs.len());
} else if outputs.is_empty() {
// Terminal push: fold is a singleton reference target with no downstream.
quote_spanned! {op_span=>
let #ident = #root::dfir_pipes::push::for_each(|#item_ident| {
#assign_accum_ident

#foreach_body
});
}
} else {
let output = &outputs[0];
quote_spanned! {op_span=>
let #ident = {
#[inline(always)]
fn __push_fold<'a, Acc, Item, CombFn, Next>(
acc_ref: &'a mut Acc,
comb_fn: CombFn,
next: Next,
) -> #root::dfir_pipes::push::Accumulate<
#root::dfir_pipes::push::FoldState<&'a mut Acc, CombFn, Acc, Item>,
Next,
>
where
CombFn: ::std::ops::FnMut(&mut Acc, Item),
Next: #root::dfir_pipes::push::Push<&'a mut Acc, ()>,
{
#root::dfir_pipes::push::fold(acc_ref, comb_fn, next)
}
__push_fold(
&mut #singleton_output_ident,
|#accumulator_ident: &mut _, #item_ident| { #foreach_body },
#root::dfir_pipes::push::map(
|__val: &mut _| ::std::clone::Clone::clone(&*__val),
#output,
),
)
};
}
};

Ok(OperatorWriteOutput {
Expand Down
24 changes: 18 additions & 6 deletions dfir_lang/src/graph/ops/fold_keyed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use quote::{ToTokens, quote_spanned};

use super::{
DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
};

Expand Down Expand Up @@ -79,15 +79,15 @@ pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
flo_type: None,
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| Some(DelayType::Stratum),
input_delaytype_fn: |_| None,
write_fn: |wc @ &WriteContextArgs {
op_span,
work_fn_async,
ident,
inputs,
outputs,
is_pull,
root,
op_name,
op_inst:
OperatorInstance {
generics:
Expand All @@ -102,8 +102,6 @@ pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
..
},
_| {
assert!(is_pull, "TODO(mingwei): `{}` only supports pull.", op_name);

let persistence = match persistence_args[..] {
[] => Persistence::Tick,
[a] => a,
Expand Down Expand Up @@ -143,7 +141,21 @@ pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
let mut #hashtable_ident = &mut #singleton_output_ident;
};

let write_iterator = if Persistence::Mutable == persistence {
let write_iterator = if !is_pull {
assert!(
Persistence::Mutable != persistence,
"fold_keyed::<'mutable> on push side is not supported ('mutable is being removed)"
);
let output = &outputs[0];
quote_spanned! {op_span=>
let #ident = #root::dfir_pipes::push::FoldKeyed::new(
&mut #singleton_output_ident,
#initfn,
#aggfn,
#output,
);
}
} else if Persistence::Mutable == persistence {
quote_spanned! {op_span=>
#assign_hashtable_ident

Expand Down
49 changes: 45 additions & 4 deletions dfir_lang/src/graph/ops/fold_no_replay.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use quote::quote_spanned;

use super::{
DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
RANGE_1, WriteContextArgs,
};

Expand All @@ -28,7 +28,7 @@ pub const FOLD_NO_REPLAY: OperatorConstraints = OperatorConstraints {
flo_type: None,
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| Some(DelayType::Stratum),
input_delaytype_fn: |_| None,
write_fn: |wc @ &WriteContextArgs {
root,
context,
Expand Down Expand Up @@ -117,15 +117,56 @@ pub const FOLD_NO_REPLAY: OperatorConstraints = OperatorConstraints {
)
};
}
} else {
assert_eq!(0, outputs.len());
} else if outputs.is_empty() {
// Terminal push: fold_no_replay is a singleton reference target with no downstream.
quote_spanned! {op_span=>
let #ident = #root::dfir_pipes::push::for_each(|#item_ident| {
#assign_accum_ident

#foreach_body
});
}
} else {
let output = &outputs[0];
let was_updated_ident = wc.make_ident("was_updated");
quote_spanned! {op_span=>
let #was_updated_ident = ::std::cell::Cell::new(false);
let #ident = {
#[inline(always)]
fn __push_fold<'a, Acc, Item, CombFn, Next>(
acc_ref: &'a mut Acc,
comb_fn: CombFn,
next: Next,
) -> #root::dfir_pipes::push::Accumulate<
#root::dfir_pipes::push::FoldState<&'a mut Acc, CombFn, Acc, Item>,
Next,
>
where
CombFn: ::std::ops::FnMut(&mut Acc, Item),
Next: #root::dfir_pipes::push::Push<&'a mut Acc, ()>,
{
#root::dfir_pipes::push::fold(acc_ref, comb_fn, next)
}
__push_fold(
&mut #singleton_output_ident,
|#accumulator_ident: &mut _, #item_ident| {
#was_updated_ident.set(true);
#foreach_body
},
#root::dfir_pipes::push::filter(
{
let __was_updated = &#was_updated_ident;
let __context: &_ = #context;
move |_| __was_updated.get() || __context.current_tick().0 == 0
},
#root::dfir_pipes::push::map(
|__val: &mut _| ::std::clone::Clone::clone(&*__val),
#output,
),
),
Comment thread
MingweiSamuel marked this conversation as resolved.
)
};
}
};

Ok(OperatorWriteOutput {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error[E0271]: type mismatch resolving `<Iter<Drain<'_, '_, &str>> as Pull>::Item == (_, _)`
error[E0271]: type mismatch resolving `<impl Pull<Item = &str, Meta = (), CanPend = <Iter<&mut impl Iterator<Item = &str>> as Pull>::CanPend, CanEnd = <Iter<&mut impl Iterator<Item = &str>> as Pull>::CanEnd> as Pull>::Item == (_, _)`
--> tests/compile-fail/stable/surface_fold_keyed_generics_bad.rs:3:9
|
3 | source_iter(["hello", "world"])
Expand All @@ -14,7 +14,7 @@ note: required by a bound in `check_input`
4 | -> fold_keyed::<'tick, &str, usize>(String::new, |old: &mut _, val| {
| ^^^^^^^^^^ required by this bound in `check_input`

error[E0271]: type mismatch resolving `<Iter<Drain<'_, '_, &str>> as Pull>::Item == (_, _)`
error[E0271]: type mismatch resolving `<impl Pull<Item = &str, Meta = (), CanPend = <Iter<&mut impl Iterator<Item = &str>> as Pull>::CanPend, CanEnd = <Iter<&mut impl Iterator<Item = &str>> as Pull>::CanEnd> as Pull>::Item == (_, _)`
--> tests/compile-fail/stable/surface_fold_keyed_generics_bad.rs:4:16
|
4 | -> fold_keyed::<'tick, &str, usize>(String::new, |old: &mut _, val| {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error[E0271]: type mismatch resolving `<Iter<Drain<'_, '_, {integer}>> as Pull>::Item == (_, _)`
error[E0271]: type mismatch resolving `<impl Pull<Item = {integer}, Meta = (), CanPend = <Iter<&mut impl Iterator<Item = {integer}>> as Pull>::CanPend, CanEnd = <Iter<&mut impl Iterator<Item = {integer}>> as Pull>::CanEnd> as Pull>::Item == (_, _)`
--> tests/compile-fail/stable/surface_reduce_keyed_badtype_int.rs:3:9
|
3 | source_iter(0..1)
Expand All @@ -14,7 +14,7 @@ note: required by a bound in `check_input`
4 | -> fold_keyed(|| 0, |old: &mut u32, val: u32| { *old += val; })
| ^^^^^^^^^^ required by this bound in `check_input`

error[E0271]: type mismatch resolving `<Iter<Drain<'_, '_, {integer}>> as Pull>::Item == (_, _)`
error[E0271]: type mismatch resolving `<impl Pull<Item = {integer}, Meta = (), CanPend = <Iter<&mut impl Iterator<Item = {integer}>> as Pull>::CanPend, CanEnd = <Iter<&mut impl Iterator<Item = {integer}>> as Pull>::CanEnd> as Pull>::Item == (_, _)`
--> tests/compile-fail/stable/surface_reduce_keyed_badtype_int.rs:4:16
|
4 | -> fold_keyed(|| 0, |old: &mut u32, val: u32| { *old += val; })
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error[E0271]: type mismatch resolving `<Iter<Drain<'_, '_, Option<{integer}>>> as Pull>::Item == (_, _)`
error[E0271]: type mismatch resolving `<impl Pull<Item = Option<{integer}>, Meta = (), CanPend = <Iter<&mut impl Iterator<Item = Option<{integer}>>> as Pull>::CanPend, CanEnd = <Iter<&mut impl Iterator<Item = Option<{integer}>>> as Pull>::CanEnd> as Pull>::Item == (_, _)`
--> tests/compile-fail/stable/surface_reduce_keyed_badtype_option.rs:3:9
|
3 | source_iter([ Some(5), None, Some(12) ])
Expand All @@ -14,7 +14,7 @@ note: required by a bound in `check_input`
4 | -> fold_keyed(|| 0, |old: &mut u32, val: u32| { *old += val; })
| ^^^^^^^^^^ required by this bound in `check_input`

error[E0271]: type mismatch resolving `<Iter<Drain<'_, '_, Option<{integer}>>> as Pull>::Item == (_, _)`
error[E0271]: type mismatch resolving `<impl Pull<Item = Option<{integer}>, Meta = (), CanPend = <Iter<&mut impl Iterator<Item = Option<{integer}>>> as Pull>::CanPend, CanEnd = <Iter<&mut impl Iterator<Item = Option<{integer}>>> as Pull>::CanEnd> as Pull>::Item == (_, _)`
--> tests/compile-fail/stable/surface_reduce_keyed_badtype_option.rs:4:16
|
4 | -> fold_keyed(|| 0, |old: &mut u32, val: u32| { *old += val; })
Expand Down
4 changes: 2 additions & 2 deletions dfir_rs/tests/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async fn test_handoff_metrics() {
let mut flow = dfir_rs::dfir_syntax! {
source_iter(0..5)
-> map(|x| x * 2)
-> fold(|| 0, |acc: &mut _, x| { *acc += x; })
-> defer_tick()
-> for_each(|x| { output_send.send(x).unwrap(); });
};

Expand All @@ -99,7 +99,7 @@ async fn test_handoff_metrics() {

// Verify output
let output: Vec<_> = collect_ready_async(&mut output_recv).await;
assert_eq!(output, vec![20]);
assert_eq!(output, vec![0, 2, 4, 6, 8]);
}

#[multiplatform_test(dfir)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: dfir_rs/tests/surface_cross_singleton.rs
assertion_line: 58
expression: df.meta_graph().unwrap().to_dot(cfg)
---
digraph {
Expand All @@ -16,7 +15,7 @@ digraph {
n8v1 [label="(n8v1) cross_singleton()", shape=invhouse, fillcolor="#88aaff"]
n9v1 [label="(n9v1) tee()", shape=house, fillcolor="#ffff88"]
n10v1 [label="(n10v1) for_each(|x| egress_tx.send(x).unwrap())", shape=house, fillcolor="#ffff88"]
n11v1 [label="(n11v1) fold(|| 0, |_, _| {})", shape=invhouse, fillcolor="#88aaff"]
n11v1 [label="(n11v1) fold(|| 0, |_, _| {})", shape=house, fillcolor="#ffff88"]
n12v1 [label="(n12v1) cross_singleton()", shape=invhouse, fillcolor="#88aaff"]
n13v1 [label="(n13v1) fold(|| 0, |_, _| {})", shape=invhouse, fillcolor="#88aaff"]
n14v1 [label="(n14v1) flat_map(|_| [])", shape=invhouse, fillcolor="#88aaff"]
Expand All @@ -26,8 +25,6 @@ digraph {
n18v1 [label="(n18v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n19v1 [label="(n19v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n20v1 [label="(n20v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n21v1 [label="(n21v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n22v1 [label="(n22v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n2v1 -> n3v1
n1v1 -> n15v1
n3v1 -> n16v1
Expand All @@ -38,19 +35,17 @@ digraph {
n7v1 -> n18v1
n8v1 -> n9v1
n9v1 -> n10v1
n9v1 -> n19v1
n3v1 -> n20v1
n11v1 -> n21v1
n9v1 -> n11v1
n3v1 -> n19v1
n11v1 -> n20v1
n13v1 -> n14v1
n12v1 -> n22v1
n12v1 -> n13v1
n15v1 -> n2v1 [color=red]
n16v1 -> n8v1 [label="input"]
n17v1 -> n4v1 [color=red]
n18v1 -> n8v1 [label="single", color=red]
n19v1 -> n11v1 [color=red]
n20v1 -> n12v1 [label="input"]
n21v1 -> n12v1 [label="single", color=red]
n22v1 -> n13v1 [color=red]
n19v1 -> n12v1 [label="input"]
n20v1 -> n12v1 [label="single", color=red]
subgraph sg_1v1 {
cluster=true
fillcolor="#dddddd"
Expand Down Expand Up @@ -98,6 +93,11 @@ digraph {
style=filled
label = "sg_4v1"
n10v1
subgraph sg_4v1_var_folded_thing {
cluster=true
label="var folded_thing"
n11v1
}
subgraph sg_4v1_var_join {
cluster=true
label="var join"
Expand All @@ -110,33 +110,16 @@ digraph {
fillcolor="#dddddd"
style=filled
label = "sg_5v1"
subgraph sg_5v1_var_folded_thing {
subgraph sg_5v1_var_deferred_stream {
cluster=true
label="var folded_thing"
n11v1
label="var deferred_stream"
n13v1
n14v1
}
}
subgraph sg_6v1 {
cluster=true
fillcolor="#dddddd"
style=filled
label = "sg_6v1"
subgraph sg_6v1_var_joined_folded {
subgraph sg_5v1_var_joined_folded {
cluster=true
label="var joined_folded"
n12v1
}
}
subgraph sg_7v1 {
cluster=true
fillcolor="#dddddd"
style=filled
label = "sg_7v1"
subgraph sg_7v1_var_deferred_stream {
cluster=true
label="var deferred_stream"
n13v1
n14v1
}
}
}
Loading
Loading