load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "xla_test")

# Package-wide settings. Every target here is under the "notice" license.
# The commented-out default_applicable_licenses line is a copybara directive
# (presumably uncommented only in the exported TensorFlow tree — do not edit
# it by hand). No default_visibility is set, so targets without an explicit
# `visibility` attribute are private to this package.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    licenses = ["notice"],
)

# Emitter for in-place dynamic-update-slice fusions, built on :fusion_emitter
# and the LLVM IR helpers under //xla/service/llvm_ir (including
# dynamic_update_slice_util). Counterpart of :in_place_dynamic_update_slice_mlir.
# Package-private; aggregated by :fusions.
cc_library(
    name = "in_place_dynamic_update_slice",
    srcs = ["in_place_dynamic_update_slice.cc"],
    hdrs = ["in_place_dynamic_update_slice.h"],
    deps = [
        ":fusion_emitter",
        "//xla:statusor",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/llvm_ir:dynamic_update_slice_util",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
    ],
)

# Unit test for :in_place_dynamic_update_slice. This is an xla_cc_test (no
# backend list) that depends on gpu_device_info_for_tests, so it presumably
# runs against a canned device description rather than real hardware.
xla_cc_test(
    name = "in_place_dynamic_update_slice_test",
    srcs = ["in_place_dynamic_update_slice_test.cc"],
    deps = [
        ":fusions",
        ":in_place_dynamic_update_slice",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:affine_map_printer",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# MLIR-based emitter for in-place dynamic-update-slice fusions. It uses the
# shared MLIR fusion infrastructure in //xla/service/gpu/fusions/mlir
# (computation_partitioner, elemental_hlo_to_mlir, mlir_fusion_emitter).
# Unlike the non-MLIR variant, it does not list :fusion_emitter directly.
# Package-private; aggregated by :fusions.
cc_library(
    name = "in_place_dynamic_update_slice_mlir",
    srcs = ["in_place_dynamic_update_slice_mlir.cc"],
    hdrs = ["in_place_dynamic_update_slice_mlir.h"],
    deps = [
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# GPU-backend test for :in_place_dynamic_update_slice_mlir, built on the shared
# :mlir_emitter_test_base fixture.
xla_test(
    name = "in_place_dynamic_update_slice_mlir_test",
    srcs = ["in_place_dynamic_update_slice_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":in_place_dynamic_update_slice_mlir",
        ":mlir_emitter_test_base",
        "//xla:error_spec",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Emitter for copy fusions. It depends on runtime:copy_thunk and has no LLVM
# deps, so it presumably emits copy thunks rather than a generated kernel
# (TODO confirm in copy.cc). Package-private; aggregated by :fusions.
cc_library(
    name = "copy",
    srcs = ["copy.cc"],
    hdrs = ["copy.h"],
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service:buffer_assignment",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu/runtime:copy_thunk",
        "//xla/service/gpu/runtime:thunk",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Emitter for custom fusions. Its deps cover custom kernels, custom-call
# targets (including the FFI), and the custom_call, dynamic_slice, gemm and
# kernel runtime thunks. GOOGLE_CUDA=1 is defined only when CUDA is
# configured; custom.cc presumably guards CUDA-only paths with it.
# Package-private; aggregated by :fusions.
cc_library(
    name = "custom",
    srcs = ["custom.cc"],
    hdrs = ["custom.h"],
    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla:statusor",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/ffi:attribute_map",
        "//xla/ffi:ffi_api",
        "//xla/hlo/ir:hlo",
        "//xla/service:buffer_assignment",
        "//xla/service:custom_call_status",
        "//xla/service:custom_call_target_registry",
        "//xla/service:hlo_proto_cc",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_arguments",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu/kernels:custom_kernel",
        "//xla/service/gpu/kernels:custom_kernel_fusion",
        "//xla/service/gpu/runtime:custom_call_thunk",
        "//xla/service/gpu/runtime:dynamic_slice_thunk",
        "//xla/service/gpu/runtime:gemm_thunk",
        "//xla/service/gpu/runtime:kernel_thunk",
        "//xla/service/gpu/runtime:thunk",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AsmParser",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# GPU-backend test for address-computation fusions (uses
# address_computation_fusion_rewriter). `srcs` is empty unless CUDA is
# configured, so the test has no source in non-CUDA builds. CUDA or ROCm
# headers are added depending on which toolkit is configured. The entry
# point comes from @local_tsl//tsl/platform:test_main.
xla_test(
    name = "address_computation_fusion_test",
    srcs = if_cuda_is_configured(["address_computation_fusion_test.cc"]),
    backends = ["gpu"],
    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
    deps = [
        "//xla:error_spec",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/client:xla_builder",
        "//xla/client:xla_computation",
        "//xla/client/lib:constants",
        "//xla/ffi",
        "//xla/ffi:ffi_api",
        "//xla/hlo/ir:hlo",
        "//xla/service:custom_call_target_registry",
        "//xla/service:executable",
        "//xla/service:hlo_module_config",
        "//xla/service:hlo_proto_cc",
        "//xla/service/gpu:address_computation_fusion_rewriter",
        "//xla/stream_executor",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/gpu:gpu_types_header",
        "//xla/tests:hlo_test_base",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
        "@local_tsl//tsl/platform:test",
        "@local_tsl//tsl/platform:test_main",
    ] + if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
    ]) + if_rocm_is_configured([
        "@local_config_rocm//rocm:rocm_headers",
    ]),
)

# Common fusion-emitter library that most emitter targets in this file depend
# on (kernel arguments, kernel reuse cache, launch dimensions, kernel thunks,
# LLVM IR utilities). Visible to //xla/service/gpu and all of its subpackages.
cc_library(
    name = "fusion_emitter",
    srcs = ["fusion_emitter.cc"],
    hdrs = ["fusion_emitter.h"],
    visibility = ["//xla/service/gpu:__subpackages__"],
    deps = [
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:statusor",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_arguments",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:target_util",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/gpu/runtime:kernel_thunk",
        "//xla/service/gpu/runtime:thunk",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Aggregation library that depends on every concrete emitter in this package,
# both the non-MLIR and *_mlir variants plus copy, custom, cudnn and triton.
# fusions.cc presumably selects which emitter to use for a given fusion
# (TODO confirm). Visible to //xla/service/gpu and all of its subpackages.
cc_library(
    name = "fusions",
    srcs = ["fusions.cc"],
    hdrs = ["fusions.h"],
    visibility = ["//xla/service/gpu:__subpackages__"],
    deps = [
        ":concatenate",
        ":concatenate_mlir",
        ":copy",
        ":cudnn",
        ":custom",
        ":fusion_emitter",
        ":in_place_dynamic_update_slice",
        ":in_place_dynamic_update_slice_mlir",
        ":input_slices",
        ":input_slices_mlir",
        ":loop",
        ":loop_mlir",
        ":reduction",
        ":reduction_mlir",
        ":scatter",
        ":scatter_mlir",
        ":transpose",
        ":transpose_mlir",
        ":triton",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service:buffer_assignment",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

# testonly fixture shared by the *_mlir_test targets in this file. It links
# gpu_plugin, FileCheck support, HloTestBase, and the MLIR dialect libraries
# (Affine, Arith, Complex, Func, GPU, Math, SCF, Tensor, plus the xla_gpu
# dialect) that the emitted IR may use.
cc_library(
    name = "mlir_emitter_test_base",
    testonly = True,
    srcs = ["mlir_emitter_test_base.cc"],
    hdrs = ["mlir_emitter_test_base.h"],
    deps = [
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/mlir_hlo",
        "//xla/service:gpu_plugin",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
        "//xla/service/gpu/model:affine_map_printer",
        "//xla/stream_executor:device_description",
        "//xla/tests:filecheck",
        "//xla/tests:hlo_test_base",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncExtensions",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MemRefTransforms",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:TensorDialect",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Emitter for loop fusions, built on the parallel loop emitter and LLVM IR
# helpers. Also a dependency of :loop_mlir, :scatter and :scatter_mlir, so
# changes here affect those targets. Package-private; aggregated by :fusions.
cc_library(
    name = "loop",
    srcs = ["loop.cc"],
    hdrs = ["loop.h"],
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_fusible",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:parallel_loop_emitter",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:macros",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# MLIR-based emitter for loop fusions. It depends on :loop, presumably to reuse
# code shared with the non-MLIR emitter (TODO confirm), and on the MLIR fusion
# infrastructure plus the xla_gpu dialect. Package-private; aggregated by
# :fusions.
cc_library(
    name = "loop_mlir",
    srcs = ["loop_mlir.cc"],
    hdrs = ["loop_mlir.h"],
    deps = [
        ":loop",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# GPU-backend test for :loop_mlir, built on :mlir_emitter_test_base.
xla_test(
    name = "loop_mlir_test",
    srcs = ["loop_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":loop_mlir",
        ":mlir_emitter_test_base",
        "//xla:error_spec",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# MLIR-based emitter for scatter fusions. It depends on :loop and on
# //xla/service:scatter_simplifier, and emits through the MLIR fusion
# infrastructure (Arith, SCF, Tensor, Func dialects and xla_gpu).
# Package-private; aggregated by :fusions.
cc_library(
    name = "scatter_mlir",
    srcs = ["scatter_mlir.cc"],
    hdrs = ["scatter_mlir.h"],
    deps = [
        ":loop",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:scatter_simplifier",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:DataLayoutInterfaces",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# GPU-backend test for :scatter_mlir, built on :mlir_emitter_test_base.
xla_test(
    name = "scatter_mlir_test",
    srcs = ["scatter_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":mlir_emitter_test_base",
        ":scatter_mlir",
        "//xla:error_spec",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# MLIR-based emitter for transpose fusions. It uses permutation_util and the
# MLIR type utilities (//xla/mlir/utils:type_util), and depends on
# :fusion_emitter directly (unlike the other *_mlir emitters shown here
# except :reduction_mlir). Package-private; aggregated by :fusions.
cc_library(
    name = "transpose_mlir",
    srcs = ["transpose_mlir.cc"],
    hdrs = ["transpose_mlir.h"],
    deps = [
        ":fusion_emitter",
        "//xla:permutation_util",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/mlir/utils:type_util",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# GPU-backend test for :transpose_mlir, built on :mlir_emitter_test_base.
xla_test(
    name = "transpose_mlir_test",
    srcs = ["transpose_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":mlir_emitter_test_base",
        ":transpose_mlir",
        "//xla:error_spec",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Unit test for loop-fusion emission. It reaches the emitters through :fusions
# and :fusion_emitter rather than depending on :loop directly, and uses
# gpu_device_info_for_tests, presumably for a canned device description.
xla_cc_test(
    name = "loop_test",
    srcs = ["loop_test.cc"],
    deps = [
        ":fusion_emitter",
        ":fusions",
        "//xla:status_macros",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:affine_map_printer",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Emitter for scatter fusions, built on :fusion_emitter, :loop, the parallel
# loop emitter and LLVM IR helpers. Counterpart of :scatter_mlir.
# Package-private; aggregated by :fusions.
cc_library(
    name = "scatter",
    srcs = ["scatter.cc"],
    hdrs = ["scatter.h"],
    deps = [
        ":fusion_emitter",
        ":loop",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:parallel_loop_emitter",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:llvm_util",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Unit test for :scatter (also links :fusions). Uses gpu_device_info_for_tests,
# presumably for a canned device description rather than a real GPU.
xla_cc_test(
    name = "scatter_test",
    srcs = ["scatter_test.cc"],
    deps = [
        ":fusions",
        ":scatter",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:affine_map_printer",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Tiling helpers built on the LLVM IR loop and kernel-support libraries. Used
# in this file by :reduction, :reduction_base and :transpose. Visible to
# //xla/service/gpu and all of its subpackages.
cc_library(
    name = "tiling_util",
    srcs = ["tiling_util.cc"],
    hdrs = ["tiling_util.h"],
    visibility = ["//xla/service/gpu:__subpackages__"],
    deps = [
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:target_util",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:kernel_support_library",
        "//xla/service/llvm_ir:llvm_loop",
        "//xla/service/llvm_ir:llvm_util",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Fusion emitter that wraps the Triton IR emitter
# (//xla/service/gpu:ir_emitter_triton) and triton_fusion_analysis, launching
# kernels through runtime:kernel_thunk. Visible to //xla/service/gpu and all of
# its subpackages; aggregated by :fusions.
cc_library(
    name = "triton",
    srcs = ["triton.cc"],
    hdrs = ["triton.h"],
    visibility = ["//xla/service/gpu:__subpackages__"],
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:statusor",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:ir_emitter_triton",
        "//xla/service/gpu:kernel_arguments",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu/runtime:kernel_thunk",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:llvm_util",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Unit test for :triton. GOOGLE_CUDA=1 is defined only when CUDA is configured,
# so the test source presumably has CUDA-specific sections guarded by it.
xla_cc_test(
    name = "triton_test",
    srcs = ["triton_test.cc"],
    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
    deps = [
        ":fusions",
        ":triton",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Fusion emitter for cuDNN fusions, built on runtime:cudnn_thunk. GOOGLE_CUDA=1
# is defined only when CUDA is configured. Package-private; aggregated by
# :fusions.
cc_library(
    name = "cudnn",
    srcs = ["cudnn.cc"],
    hdrs = ["cudnn.h"],
    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
    deps = [
        ":fusion_emitter",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_arguments",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/service/gpu/runtime:cudnn_thunk",
        "//xla/service/gpu/runtime:thunk",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status:statusor",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# GPU-backend test for cuDNN fusions (uses cudnn_fusion_compiler and
# gpu_codegen_test). The source is compiled only when CUDA is configured, and
# the GPU variant is tagged "requires-gpu-sm90".
#
# Exactly one dependency may provide main(). @local_tsl//tsl/platform:test_main
# supplies it, so googletest is linked via ":gtest" rather than ":gtest_main".
# Linking both gives two competing main() definitions, and which one is used
# depends on link mode and order.
xla_test(
    name = "cudnn_test",
    srcs = if_cuda_is_configured(["cudnn_test.cc"]),
    backend_tags = {"gpu": [
        "requires-gpu-sm90",
    ]},
    backends = [
        "gpu",
    ],
    deps = [
        "//xla:comparison_util",
        "//xla:debug_options_flags",
        "//xla:error_spec",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:executable",
        "//xla/service:pattern_matcher",
        "//xla/service:pattern_matcher_gmock",
        "//xla/service/gpu:cudnn_fusion_compiler",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/runtime:thunk",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:stream_executor_headers",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/tests:filecheck",
        "//xla/tests:verified_hlo_module",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
        "@local_tsl//tsl/platform:statusor",
        "@local_tsl//tsl/platform:test_main",
    ],
)

# Thunk helpers that depend on runtime:memset_thunk and literals; used by
# :reduction in this file. Visible to //xla/service/gpu and all of its
# subpackages.
cc_library(
    name = "thunk_util",
    srcs = ["thunk_util.cc"],
    hdrs = ["thunk_util.h"],
    visibility = ["//xla/service/gpu:__subpackages__"],
    deps = [
        "//xla:literal",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service:buffer_assignment",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu/runtime:memset_thunk",
        "//xla/service/gpu/runtime:thunk",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
    ],
)

# Emitter for reduction fusions, built on :reduction_base (shared with
# :reduction_mlir), :tiling_util, :thunk_util and the LLVM IR loop and kernel
# helpers. Package-private; aggregated by :fusions.
cc_library(
    name = "reduction",
    srcs = ["reduction.cc"],
    hdrs = ["reduction.h"],
    deps = [
        ":fusion_emitter",
        ":reduction_base",
        ":thunk_util",
        ":tiling_util",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:buffer_assignment",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_arguments",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:parallel_loop_emitter",
        "//xla/service/gpu:reduction_utils",
        "//xla/service/gpu:target_util",
        "//xla/service/gpu/runtime:kernel_thunk",
        "//xla/service/gpu/runtime:thunk",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:kernel_support_library",
        "//xla/service/llvm_ir:llvm_loop",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/service/llvm_ir:loop_emitter",
        "//xla/stream_executor:device_description",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:logging",
        "@local_tsl//tsl/platform:status",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Unit test for :reduction (xla_cc_test; no backend list). Uses
# gpu_device_info_for_tests and indexing_analysis, presumably to check
# indexing/launch behaviour against a canned device description.
xla_cc_test(
    name = "reduction_test",
    srcs = ["reduction_test.cc"],
    deps = [
        ":fusion_emitter",
        ":reduction",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
    ],
)

# Shared base for the reduction emitters: both :reduction and :reduction_mlir
# depend on it. It uses union_find, hlo_query, gpu_fusible, reduction_utils
# and :tiling_util. Package-private.
cc_library(
    name = "reduction_base",
    srcs = ["reduction_base.cc"],
    hdrs = ["reduction_base.h"],
    deps = [
        ":fusion_emitter",
        ":tiling_util",
        "//xla:shape_util",
        "//xla:union_find",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/service/gpu:gpu_fusible",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:reduction_utils",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)

# MLIR-based emitter for reduction fusions, sharing :reduction_base with
# :reduction. Its dialect deps include GPU, SCF, Vector and bufferization
# interfaces in addition to the common MLIR fusion infrastructure.
# Package-private; aggregated by :fusions.
cc_library(
    name = "reduction_mlir",
    srcs = ["reduction_mlir.cc"],
    hdrs = ["reduction_mlir.h"],
    deps = [
        ":fusion_emitter",
        ":reduction_base",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:reduction_utils",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/fusions/mlir:type_util",
        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:BufferizationInterfaces",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:VectorDialect",
    ],
)

# GPU-backend test for :reduction_mlir, built on :mlir_emitter_test_base.
# Unlike the other *_mlir_test targets, it does not list hlo_fusion_analysis or
# tsl statusor directly; it presumably gets what it needs transitively.
xla_test(
    name = "reduction_mlir_test",
    srcs = ["reduction_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":mlir_emitter_test_base",
        ":reduction_mlir",
        "//xla:error_spec",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
    ],
)

# Emitter for concatenate fusions, built on the parallel loop emitter and LLVM
# IR helpers. :concatenate_mlir also depends on it. Package-private; aggregated
# by :fusions.
cc_library(
    name = "concatenate",
    srcs = ["concatenate.cc"],
    hdrs = ["concatenate.h"],
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:parallel_loop_emitter",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:loop_emitter",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Unit test for :concatenate (also links :fusions). Uses
# gpu_device_info_for_tests, presumably for a canned device description.
xla_cc_test(
    name = "concatenate_test",
    srcs = ["concatenate_test.cc"],
    deps = [
        ":concatenate",
        ":fusions",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:affine_map_printer",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# MLIR variant of the concatenate emitter
# (concatenate_mlir.h / concatenate_mlir.cc).
# Reuses the non-MLIR :concatenate target. It is built on the MLIR fusion
# infrastructure under //xla/service/gpu/fusions/mlir: computation_partitioner,
# elemental_hlo_to_mlir and mlir_fusion_emitter.
cc_library(
    name = "concatenate_mlir",
    srcs = ["concatenate_mlir.cc"],
    hdrs = ["concatenate_mlir.h"],
    deps = [
        ":concatenate",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# Test for the :concatenate_mlir library.
# Declared via xla_test with backends = ["gpu"], so it is only instantiated for
# the GPU backend. It builds on the shared :mlir_emitter_test_base target.
xla_test(
    name = "concatenate_mlir_test",
    srcs = ["concatenate_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":concatenate_mlir",
        ":mlir_emitter_test_base",
        "//xla:error_spec",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Transpose emitter (transpose.h / transpose.cc).
# Built on :fusion_emitter and :tiling_util, plus //xla:permutation_util and the
# //xla/service/llvm_ir helpers (fused_ir_emitter, ir_array, llvm_util,
# loop_emitter). It also uses GPU target_util.
cc_library(
    name = "transpose",
    srcs = ["transpose.cc"],
    hdrs = ["transpose.h"],
    deps = [
        ":fusion_emitter",
        ":tiling_util",
        "//xla:permutation_util",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:target_util",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/service/llvm_ir:loop_emitter",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Unit test for the :transpose library.
# Declared via xla_cc_test with no `backends` list. It uses the test-only GPU
# device description (gpu_device_info_for_tests) together with :fusions.
xla_cc_test(
    name = "transpose_test",
    srcs = ["transpose_test.cc"],
    deps = [
        ":fusions",
        ":transpose",
        "//xla:status_macros",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Non-MLIR input-slices emitter (input_slices.h / input_slices.cc).
# Built on :fusion_emitter, //xla/service:elemental_ir_emitter, the GPU
# parallel_loop_emitter and the //xla/service/llvm_ir helpers (fused_ir_emitter,
# ir_array, kernel_support_library, llvm_loop).
cc_library(
    name = "input_slices",
    srcs = ["input_slices.cc"],
    hdrs = ["input_slices.h"],
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/service:elemental_ir_emitter",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:parallel_loop_emitter",
        "//xla/service/gpu/model:indexing_analysis",
        "//xla/service/llvm_ir:fused_ir_emitter",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:kernel_support_library",
        "//xla/service/llvm_ir:llvm_loop",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# MLIR variant of the input-slices emitter
# (input_slices_mlir.h / input_slices_mlir.cc).
# Unlike :concatenate_mlir, it does NOT depend on its non-MLIR counterpart
# (:input_slices). It builds directly on the //xla/service/gpu/fusions/mlir
# infrastructure, including the xla_gpu dialect target under mlir/ir.
cc_library(
    name = "input_slices_mlir",
    srcs = ["input_slices_mlir.cc"],
    hdrs = ["input_slices_mlir.h"],
    deps = [
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:hlo_traversal",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/fusions/mlir:computation_partitioner",
        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
        "//xla/service/gpu/model:indexing_analysis",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# Test for the :input_slices_mlir library.
# Declared via xla_test with backends = ["gpu"], so it is only instantiated for
# the GPU backend. It builds on the shared :mlir_emitter_test_base target.
xla_test(
    name = "input_slices_mlir_test",
    srcs = ["input_slices_mlir_test.cc"],
    backends = ["gpu"],
    deps = [
        ":input_slices_mlir",
        ":mlir_emitter_test_base",
        "//xla:error_spec",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/core:status_test_util",
    ],
)

# Unit test for the non-MLIR :input_slices library.
# Declared via xla_cc_test with no `backends` list. It uses the test-only GPU
# device description (gpu_device_info_for_tests) together with :fusions.
# NOTE(review): the affine_map_printer / indexing_test_utils deps suggest the
# test inspects indexing maps; confirm against input_slices_test.cc.
xla_cc_test(
    name = "input_slices_test",
    srcs = ["input_slices_test.cc"],
    deps = [
        ":fusions",
        ":input_slices",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu/model:affine_map_printer",
        "//xla/service/gpu/model:indexing_test_utils",
        "//xla/stream_executor:device_description",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:statusor",
    ],
)
