Better graph selection api #33

Open
jafioti opened this issue Mar 3, 2024 · 1 comment

@jafioti
Owner

jafioti commented Mar 3, 2024

Currently, the graph selection API makes it difficult to write selectors for complex patterns such as RoPE (rotary positional embeddings):

```rust
let freqs = select_const!(1000000.0_f32.ln(), T)
    // ...
```

Selectors should instead be built the same way primgraphs already are, with a `GraphTensor`-like API (but without compile-time shapes).
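
To make the proposal concrete, here is a minimal, self-contained sketch of what a `GraphTensor`-like selector builder could look like. Everything in it (`Pat`, `SelTensor`, `rope_freqs_selector`, `count_ops`) is hypothetical and not luminal's API: each arithmetic operator builds a pattern node instead of computing a value, and the handle carries no compile-time shape.

```rust
use std::ops::Mul;
use std::rc::Rc;

/// Hypothetical pattern node: match anything, a specific constant, or an op with inputs.
#[derive(Debug, PartialEq)]
enum Pat {
    Any,
    Const(f32),
    Op(&'static str, Vec<Rc<Pat>>),
}

/// Hypothetical GraphTensor-like handle with no compile-time shape parameter.
#[derive(Debug, Clone)]
struct SelTensor(Rc<Pat>);

impl SelTensor {
    fn any() -> Self {
        SelTensor(Rc::new(Pat::Any))
    }
    fn constant(v: f32) -> Self {
        SelTensor(Rc::new(Pat::Const(v)))
    }
    fn unary(self, name: &'static str) -> Self {
        SelTensor(Rc::new(Pat::Op(name, vec![self.0])))
    }
    fn exp(self) -> Self {
        self.unary("Exp")
    }
    fn recip(self) -> Self {
        self.unary("Recip")
    }
}

/// `a * b` builds a Mul pattern node, the way GraphTensor ops build graph nodes.
impl Mul for SelTensor {
    type Output = SelTensor;
    fn mul(self, rhs: SelTensor) -> SelTensor {
        SelTensor(Rc::new(Pat::Op("Mul", vec![self.0, rhs.0])))
    }
}

/// The inverse-frequency part of a RoPE selector, written as ordinary math.
fn rope_freqs_selector() -> SelTensor {
    let arange = SelTensor::any();
    let inv_head_dim = SelTensor::any();
    (arange * SelTensor::constant(2.0) * inv_head_dim * SelTensor::constant(1_000_000f32.ln()))
        .exp()
        .recip()
}

/// Count op nodes in a pattern (wildcards and constants are not counted).
fn count_ops(p: &Pat) -> usize {
    match p {
        Pat::Op(_, inputs) => 1 + inputs.iter().map(|i| count_ops(i)).sum::<usize>(),
        _ => 0,
    }
}

fn main() {
    let freqs = rope_freqs_selector();
    // Recip(Exp(Mul(Mul(Mul(any, 2), any), ln(1e6)))): three Muls, Exp, Recip.
    println!("{} op nodes", count_ops(&freqs.0)); // prints "5 op nodes"
}
```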

@jafioti
Owner Author

jafioti commented Mar 4, 2024

I didn't end up taking that approach, but the new selector API makes complex patterns much easier to write. The RoPE selector went from:

```rust
let freqs = select_const!(1000000.0_f32.ln(), T)
    .ptr(&mut theta)
    .edge(
        select_ty!(MetalConstant<T>)
            .ptr(&mut inv_head_dim)
            .edge(
                select_ty!(MetalConstant<T>)
                    .ptr(&mut two)
                    .edge(
                        select_ty!(crate::other::MetalARange<T>)
                            .ptr(&mut head_dim_arange)
                            .edge(select_ty!(MetalMul<T>).ptr(&mut mul_2)),
                    )
                    .edge(select_ty!(MetalMul<T>).ptr(&mut head_dim_mul)),
            )
            .edge(select_ty!(MetalMul<T>).ptr(&mut theta_mul)),
    )
    .edge(select_ty!(MetalExp<T>).ptr(&mut exp))
    .edge(select_ty!(MetalRecip<T>).ptr(&mut recip));
let seq = select_ty!(MetalConstant<T>).ptr(&mut seq_expr).edge(
    select_ty!(crate::other::MetalARange<T>)
        .ptr(&mut seq_arange)
        .edge(select_ty!(MetalAdd<T>).ptr(&mut seq_add)),
);
let emb = freqs.edge(seq.edge(select_ty!(MetalMul<T>).ptr(&mut freq_seq_mul)));
let split = SelectOp::new()
    .ptr(&mut input)
    .edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig1));
let x0 = split
    .clone()
    .edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig2));
let x1 = split.edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig3));
let x0_sin = emb
    .clone()
    .edge(select_ty!(MetalSin<T>).ptr(&mut sin1))
    .edge(x0.clone().edge(select_ty!(MetalMul<T>).ptr(&mut out_mul1)));
let x0_cos = emb
    .clone()
    .edge(select_ty!(MetalCos<T>).ptr(&mut cos1))
    .edge(x0.edge(select_ty!(MetalMul<T>).ptr(&mut out_mul2)));
let x1_sin = emb
    .clone()
    .edge(select_ty!(MetalSin<T>).ptr(&mut sin2))
    .edge(x1.clone().edge(select_ty!(MetalMul<T>).ptr(&mut out_mul3)));
let x1_cos = emb
    .clone()
    .edge(select_ty!(MetalCos<T>).ptr(&mut cos2))
    .edge(x1.edge(select_ty!(MetalMul<T>).ptr(&mut out_mul4)));
let x0_out = x1_sin.edge(x0_cos.edge(select_ty!(MetalSub<T>).ptr(&mut out_sub)));
let x1_out = x0_sin.edge(x1_cos.edge(select_ty!(MetalAdd<T>).ptr(&mut out_add)));
let mut searcher = x1_out
    .edge(x0_out.edge(select_ty!(MetalAdd<T>).ptr(&mut final_add)))
    .search(graph);
```

to:

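```rust
// Inverse frequencies: recip(exp(arange * 2 * c * ln(1e6))) = 1e6^(-2 * arange * c),
// where the MetalConstant `c` is presumably 1/head_dim (the old selector bound it to `inv_head_dim`).
let freqs = binary::<MetalMul<T>>(op::<MetalARange<T>>(), constant::<T>(2.0));
let freqs = binary::<MetalMul<T>>(freqs, op::<MetalConstant<T>>());
let freqs = binary::<MetalMul<T>>(freqs, constant::<T>((1000000_f32).abs().ln()));
let freqs = unary::<MetalRecip<T>>(unary::<MetalExp<T>>(freqs));
// Positions (arange offset by `prev_seq`) times the frequencies.
let prev_seq = op::<MetalConstant<T>>();
let emb = binary::<MetalMul<T>>(
    binary::<MetalAdd<T>>(op::<MetalARange<T>>(), prev_seq.clone()),
    freqs,
);
// The input and the contiguous copies feeding the rotated halves.
let inp = node();
let split = unary::<MetalContiguous<T>>(inp.clone());
let x0 = unary::<MetalContiguous<T>>(split.clone());
// Rotated halves. The bare op::<...>() selectors stand in for subpatterns that
// would otherwise reuse nodes such as `x0` or `emb`, presumably because a
// selector node can't be referenced twice yet.
let x0_out = binary::<MetalSub<T>>(
    binary::<MetalMul<T>>(x0, unary::<MetalSin<T>>(emb.clone())),
    binary::<MetalMul<T>>(op::<MetalContiguous<T>>(), op::<MetalCos<T>>()),
);
let x1_out = binary::<MetalAdd<T>>(op::<MetalMul<T>>(), op::<MetalMul<T>>());
// Root: the final add joining both halves, then search the graph for matches.
let add = binary::<MetalAdd<T>>(x0_out, x1_out);
let mut s = add.clone().search(graph);
```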
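
As a rough mental model of how combinators like `node()`, `op::<T>()`, `unary` and `binary` compose into a pattern, here is a self-contained toy in plain Rust. It is not luminal's implementation: ops are identified by string names rather than types, `Graph`, `Sel`, `search`, `freqs_graph` and `rope_freqs_selector` are hypothetical, constants are matched only by op name (whereas `constant::<T>(2.0)` above also names a value), and inputs are matched in the order written.

```rust
/// Toy computation graph: node `i` is `(op name, input node ids)`.
struct Graph {
    nodes: Vec<(&'static str, Vec<usize>)>,
}

/// Hypothetical selector tree loosely mirroring the combinators above.
enum Sel {
    Node,                         // any node, like `node()`
    Op(&'static str),             // an op type with unconstrained inputs, like `op::<T>()`
    With(&'static str, Vec<Sel>), // an op type plus input selectors, like `unary` / `binary`
}

fn node() -> Sel {
    Sel::Node
}
fn op(name: &'static str) -> Sel {
    Sel::Op(name)
}
fn unary(name: &'static str, a: Sel) -> Sel {
    Sel::With(name, vec![a])
}
fn binary(name: &'static str, a: Sel, b: Sel) -> Sel {
    Sel::With(name, vec![a, b])
}

impl Sel {
    /// Does this selector match the subgraph rooted at node `id`?
    fn matches(&self, g: &Graph, id: usize) -> bool {
        let (name, inputs) = &g.nodes[id];
        match self {
            Sel::Node => true,
            Sel::Op(want) => want == name,
            Sel::With(want, subs) => {
                want == name
                    && subs.len() == inputs.len()
                    && subs.iter().zip(inputs).all(|(s, &i)| s.matches(g, i))
            }
        }
    }
}

/// Ids of every node where `sel` matches, in graph order.
fn search(sel: &Sel, g: &Graph) -> Vec<usize> {
    (0..g.nodes.len()).filter(|&id| sel.matches(g, id)).collect()
}

/// recip(exp(((arange * c) * c) * c)) as a toy graph.
fn freqs_graph() -> Graph {
    Graph {
        nodes: vec![
            ("ARange", vec![]),   // 0
            ("Constant", vec![]), // 1: e.g. 2.0
            ("Mul", vec![0, 1]),  // 2
            ("Constant", vec![]), // 3: e.g. 1/head_dim
            ("Mul", vec![2, 3]),  // 4
            ("Constant", vec![]), // 5: e.g. ln(1e6)
            ("Mul", vec![4, 5]),  // 6
            ("Exp", vec![6]),     // 7
            ("Recip", vec![7]),   // 8
        ],
    }
}

/// Same shape as the `freqs` selector in the new API.
fn rope_freqs_selector() -> Sel {
    let freqs = binary("Mul", op("ARange"), op("Constant"));
    let freqs = binary("Mul", freqs, op("Constant"));
    let freqs = binary("Mul", freqs, op("Constant"));
    unary("Recip", unary("Exp", freqs))
}

fn main() {
    let g = freqs_graph();
    println!("{:?}", search(&rope_freqs_selector(), &g)); // [8]
    println!("{:?}", search(&binary("Mul", node(), op("Constant")), &g)); // [2, 4, 6]
}
```

Because each `Sel` here is a tree, it has no way to say that two inputs must be the same graph node.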

Some correctness bugs remain. In particular, selector graphs can't yet reference the same node twice, because the backtracking function doesn't support it.
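
One way a backtracking matcher can support shared references is to record which graph node each selector node is bound to, reject a second visit that lands on a different graph node, and undo bindings when a branch fails. The toy below is a hypothetical sketch, not luminal's code: selector nodes live in a `Vec` and refer to their inputs by index (so one node can feed several others), `None` means "any op", and retrying the operand order of `Add`/`Mul` is an assumption added only to give the search something to backtrack over.

```rust
/// Toy graph: node `i` is `(op name, input node ids)`.
struct Graph {
    nodes: Vec<(&'static str, Vec<usize>)>,
}

/// Hypothetical selector DAG: node `i` is `(op name, or None for "any", input selector ids)`.
/// An empty input list leaves the matched node's inputs unconstrained.
struct Pattern {
    nodes: Vec<(Option<&'static str>, Vec<usize>)>,
}

/// Ops whose two operands may appear in either order (an assumption for this toy).
fn commutative(op: &str) -> bool {
    matches!(op, "Add" | "Mul")
}

/// Satisfy every `(selector id, graph id)` goal, backtracking on failure.
/// `bind[s]` is the graph node that selector node `s` is currently bound to.
/// On failure, `bind` is left exactly as it was on entry.
fn solve(p: &Pattern, g: &Graph, goals: &[(usize, usize)], bind: &mut [Option<usize>]) -> bool {
    let Some((&(sid, gid), rest)) = goals.split_first() else {
        return true; // nothing left to match
    };
    if let Some(prev) = bind[sid] {
        // A shared selector node must land on the same graph node every time.
        return prev == gid && solve(p, g, rest, bind);
    }
    let (want, sel_inputs) = &p.nodes[sid];
    let (name, graph_inputs) = &g.nodes[gid];
    if let Some(w) = want {
        if w != name {
            return false;
        }
    }
    bind[sid] = Some(gid);
    // Candidate operand orders: as written, plus swapped for commutative binary ops.
    let mut orders = vec![graph_inputs.clone()];
    if sel_inputs.len() == 2 && graph_inputs.len() == 2 && commutative(name) {
        orders.push(vec![graph_inputs[1], graph_inputs[0]]);
    }
    for order in orders {
        if !sel_inputs.is_empty() && sel_inputs.len() != order.len() {
            continue;
        }
        // Inputs to check next (none if unconstrained), then the remaining goals.
        let mut next: Vec<(usize, usize)> = sel_inputs.iter().copied().zip(order).collect();
        next.extend_from_slice(rest);
        if solve(p, g, &next, bind) {
            return true;
        }
    }
    bind[sid] = None; // backtrack: undo this binding
    false
}

/// Bindings for the first graph node where selector node `root` matches.
fn find(p: &Pattern, root: usize, g: &Graph) -> Option<Vec<Option<usize>>> {
    (0..g.nodes.len()).find_map(|gid| {
        let mut bind = vec![None; p.nodes.len()];
        if solve(p, g, &[(root, gid)], &mut bind) {
            Some(bind)
        } else {
            None
        }
    })
}

/// mul(sin(e), cos(e)) where both branches must read the *same* node `e`.
fn sin_cos_pattern() -> Pattern {
    Pattern {
        nodes: vec![
            (None, vec![]),            // 0: e, referenced twice
            (Some("Sin"), vec![0]),    // 1
            (Some("Cos"), vec![0]),    // 2
            (Some("Mul"), vec![1, 2]), // 3: root
        ],
    }
}

fn main() {
    let p = sin_cos_pattern();
    let shared = Graph { nodes: vec![("Input", vec![]), ("Sin", vec![0]), ("Cos", vec![0]), ("Mul", vec![1, 2])] };
    // Sin and Cos read different inputs here, so the shared node `e` cannot bind.
    let distinct = Graph { nodes: vec![("Input", vec![]), ("Input", vec![]), ("Sin", vec![0]), ("Cos", vec![1]), ("Mul", vec![2, 3])] };
    println!("{:?}", find(&p, 3, &shared)); // Some([Some(0), Some(1), Some(2), Some(3)])
    println!("{:?}", find(&p, 3, &distinct)); // None
}
```

A tree-shaped matcher would accept both graphs; the binding check is what rejects the second. The search can repeat work when many commutative nodes are involved, so this only illustrates the bind-and-undo bookkeeping, not an efficient matcher.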
