[ROCm] gpt-oss: route FA3 to aiter-flash-attn, generate ROCm fixtures#46837
Abdennacer-Badaoui wants to merge 7 commits into
vasqu
left a comment
I have a few comments. I'm especially hesitant on the tests side: we change a few things that would break the original CUDA fixtures, so we have to be a bit careful.
71f135c to
cae1c1f
Compare
[For maintainers] Suggested jobs to run (before merge) run-slow: gpt_oss, openai_privacy_filter
# Generate key to look up expected outputs
key = generate_config_key(quantized, model_size, kernels, attn_impl, mode)

if os.environ.get("WRITE_FIXTURES") == "1":
Yes, haha, I was about to remove it, thanks.
| "device=rocm|quantized=false|model=20b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=eval": [ | ||
| "Roses are red, violets, vi, vi, vi, vi, vi, vi, vi, vi, vi, vi", | ||
| "How are you? Tell me the name of the president of the president of the name of the president of the name of the president of the name of the president" | ||
| ], | ||
| "device=rocm|quantized=false|model=20b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=train": [ | ||
| "Roses are red, violets, vi, vi, vi, vi, vi, vi, vi, vi, vi, vi", | ||
| "How are you? Tell me the name of the president of the president of the name of the president of the name of the president of the name of the president" | ||
| ], | ||
| "device=rocm|quantized=true|model=20b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=eval": [ | ||
| "Roses are red, violets, vi, vi, vi, vi, vi, vi, vi, vi, vi, vi", | ||
| "How are you? Tell me the name of the president of the president of the name of the president of the name of the president of the name of the president" | ||
| ], | ||
| "device=rocm|quantized=true|model=20b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=train": [ | ||
| "Roses are red, violets, vi, vi, vi, vi, vi, vi, vi, vi, vi, vi", | ||
| "How are you? Tell me the name of the president of the president of the name of the president of the name of the president of the name of the president" | ||
| ], | ||
| "device=rocm|quantized=false|model=120b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=eval": [ | ||
| "Roses are red, violets red, red, red, red, red, red,,,,,,,,,", | ||
| "How are you? Tell me the name of the president of the\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" | ||
| ], | ||
| "device=rocm|quantized=false|model=120b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=train": [ | ||
| "Roses are red, violets red, red, red, red, red, red,,,,,,,,,", | ||
| "How are you? Tell me the name of the president of the\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" | ||
| ], | ||
| "device=rocm|quantized=true|model=120b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=eval": [ | ||
| "Roses are red, violets red, red, red, red, red, red,,,,,,,,,", | ||
| "How are you? Tell me the name of the president of the\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" | ||
| ], | ||
| "device=rocm|quantized=true|model=120b|kernels=false|attn_impl=kernels-community/aiter-flash-attn|mode=train": [ | ||
| "Roses are red, violets red, red, red, red, red, red,,,,,,,,,", | ||
| "How are you? Tell me the name of the president of the\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" |
OK, sorry to intervene again: I just noticed that we have significantly different outputs between FA and eager attention. This smells fishy to me. Are we sure that something is not broken?
These repeating outputs are especially weird. Is there a different naming convention for s_aux, for example?
Good catch! Thanks.
I didn't check the fixtures because I was generating them directly into the file.
Let me check what's happening here.
There's a real bug. transformers passes the attention sinks to the FA kernel under the name s_aux (what vllm-fa3 expects) or learnable_sink (what FA4 expects). Our aiter-flash-attn kernel names the same argument sink, which transformers doesn't recognize as a sink name, so it never passes the sink tensor to the kernel. The attention then runs without sinks, which is why we see the repetitive, degenerate output. I'll rename the kernel's public arg to s_aux so transformers picks it up, then regenerate the ROCm FA fixtures.
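To make the failure mode concrete, here is a minimal sketch of name-based argument forwarding. It is not the actual transformers dispatch code: `find_sink_kwarg`, `call_kernel`, and the two toy kernels are hypothetical, and only the names s_aux, learnable_sink, and sink come from the comment above.

```python
import inspect

# Sink argument names the caller is assumed to recognize (from the comment above).
KNOWN_SINK_NAMES = ("s_aux", "learnable_sink")


def find_sink_kwarg(kernel_fn):
    """Return the recognized sink parameter name of kernel_fn, or None."""
    params = inspect.signature(kernel_fn).parameters
    for name in KNOWN_SINK_NAMES:
        if name in params:
            return name
    return None


def call_kernel(kernel_fn, q, k, v, sinks=None):
    """Forward sinks only if the kernel exposes a recognized parameter name."""
    kwargs = {}
    name = find_sink_kwarg(kernel_fn)
    if sinks is not None and name is not None:
        kwargs[name] = sinks
    # If no name matched, the sinks are dropped without any error.
    return kernel_fn(q, k, v, **kwargs)


# Hypothetical kernels that only report whether they received sinks.
def kernel_named_sink(q, k, v, sink=None):
    return "with sinks" if sink is not None else "without sinks"


def kernel_named_s_aux(q, k, v, s_aux=None):
    return "with sinks" if s_aux is not None else "without sinks"


before_rename = call_kernel(kernel_named_sink, 1, 2, 3, sinks=[0.1])
after_rename = call_kernel(kernel_named_s_aux, 1, 2, 3, sinks=[0.1])
```

In this sketch the kernel that names its argument sink silently runs without sinks, while the one that names it s_aux receives them. Nothing raises in either case, which would explain why the problem only showed up in the generated text.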
Noticed significantly different outputs between eager and FA.
This PR allows multiple different names for the same argument (especially for sinks): #45153.
On ROCm, gpt-oss now routes the FA3-style attention to kernels-community/aiter-flash-attn (_compatible_flash_implementations) and generates separate ROCm fixtures (non-distributed + tp_size=2). Configs that depend on kernels-community/megablocks are skipped because it doesn't ship a ROCm build for torch 2.10, the version the AMD CI runs today (builds exist for 2.11 and 2.12). Configs that hit kernels-community/sonic-moe are skipped since that kernel has no ROCm build yet.

Fixes around 140 failing tests.
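For readers unfamiliar with the fixture layout, the sketch below shows one way a per-device lookup key of the shape seen in the ROCm fixtures could be assembled. `make_fixture_key` is a hypothetical helper, not the test suite's generate_config_key, whose internals are not shown in this thread. The field order is inferred from the fixture keys in the diff.

```python
def make_fixture_key(device, quantized, model_size, kernels, attn_impl, mode):
    """Build a 'name=value|name=value' key (hypothetical helper, for illustration)."""
    fields = [
        ("device", device),
        ("quantized", str(quantized).lower()),  # True -> "true", False -> "false"
        ("model", model_size),
        ("kernels", str(kernels).lower()),
        ("attn_impl", attn_impl),
        ("mode", mode),
    ]
    return "|".join(f"{name}={value}" for name, value in fields)


key = make_fixture_key(
    device="rocm",
    quantized=False,
    model_size="20b",
    kernels=False,
    attn_impl="kernels-community/aiter-flash-attn",
    mode="eval",
)
```

Keeping the device in the key lets ROCm and CUDA expectations live side by side, so regenerating one set does not have to touch the other.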