
[QUESTION] Not supported on A6000? #46

Open
Zhuohao-Li opened this issue Oct 26, 2024 · 3 comments

@Zhuohao-Li

Your question
Hi,

When I run the test demo on a node consisting of 2 A6000 GPUs, it reports an error:

RuntimeError: /root/opensource/flux/src/cuda/op_registry.cu:36 Check failed: arch_num == 80 || arch_num == 89 || arch_num == 90. unsupported arch: 86
So Flux only supports these three compute capabilities (cc=80, 89, 90)? Correct me if I misunderstand.

Thanks

zheng-ningxin self-assigned this Oct 28, 2024
@zheng-ningxin
Collaborator

Yes, Flux is only compiled for architectures 80, 89, and 90 for now. However, I suspect that CUTLASS v2 should directly support architecture 86. Could you try adding the corresponding arch number and recompiling to see if it works?
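
For reference, a quick standalone way to confirm the arch number a device reports (a minimal sketch using only the CUDA runtime API, nothing Flux-specific; the file name is arbitrary, and it mirrors what Flux's init_arch_tag computes):

// check_arch.cu — compile with: nvcc check_arch.cu -o check_arch
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int major = 0, minor = 0;
  // Same attributes Flux queries in init_arch_tag().
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
  printf("arch_num = %d\n", major * 10 + minor);  // an A6000 reports 86
  return 0;
}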

@Zhuohao-Li
Author

Zhuohao-Li commented Oct 28, 2024

Thanks.

I added 86 to the arch check in /flux/src/cuda/op_registry.cu line 36 like this:

void
init_arch_tag() {
  int major, minor;
  // Query device 0's compute capability and fold it into a single
  // arch number (major * 10 + minor), e.g. 86 for an A6000.
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
  int arch_num = major * 10 + minor;
  // 86 added to the set of accepted architectures.
  FLUX_CHECK(arch_num == 80 || arch_num == 89 || arch_num == 90 || arch_num == 86)
      << "unsupported arch: " << arch_num;
  arch = ArchEnum{arch_num};
}

I recompiled via Build from Source again, but when running ./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10 it fails with:

RuntimeError: ~/flux/include/flux/op_registry.h:220 Check failed: visit_iter != gemm_hparams_.end(). No registered hparams found for meta:GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32),arch=UNK,comm_op=CommNone,gemm_layout=RCR,impl=GemmV2,impl_spec=GemmV2Meta(fast_accum=0),comm_spec=None)

The corresponding code in op_registry.h:

// Iterate all hparams registered for a meta and call func.
// This can be useful for tuning.
template <class... Ts>
void
visit_hparams(std::function<void(UnifiedGemmHParams)> &&func, GemmMeta<Ts...> meta) {
  std::shared_lock<std::shared_mutex> lock(register_mutex_);
  auto unified_meta = unify_type(meta);
  auto iter = gemm_hparams_.find(unified_meta);
  FLUX_CHECK(iter != gemm_hparams_.end()) << "no op registered for meta:" << meta;
  for (const auto &hparams_pair : iter->second) {
    auto const &hparams = hparams_pair.second;
    func(hparams);
  }
}

I have not yet taken a deep look at what hparams is; if possible, could you point me to it? Are any additional changes needed elsewhere in the codebase? Thanks!

@houqi
Collaborator

houqi commented Nov 8, 2024


The arch=UNK in your error suggests ArchEnum itself has no entry for 86, so nothing gets registered for that arch. You should also:

  1. modify flux.h and add 86 to ArchEnum (see the sketch below)
  2. add an sm86 case to the workspace handling: gemm_v2_reduce_scatter.hpp#L502 for GEMM+RS, gemm_v2_ag_kernel.hpp#L174 for AG+GEMM
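
A hypothetical sketch of step 1 (the actual enumerator names and values in flux.h may differ; this only illustrates the shape of the change):

// flux.h — add an enumerator for sm86 alongside the existing ones.
// Names here are illustrative; match whatever flux.h actually uses.
enum class ArchEnum : int {
  Sm80 = 80,
  Sm86 = 86,  // added: A6000 / compute capability 8.6
  Sm89 = 89,
  Sm90 = 90,
};

With a fixed underlying type, the existing ArchEnum{arch_num} in init_arch_tag() remains valid list-initialization for the new value.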
