cornell-zhang / allo

Allo: A Programming Model for Composable Accelerator Design

Home Page: https://cornell-zhang.github.io/allo

License: Apache License 2.0


allo's Introduction


Allo is an Accelerator Design Language (ADL) and compiler that facilitates the construction of large-scale, high-performance hardware accelerators in a modular and composable manner. Allo has several key features:

  • Progressive hardware customizations: Allo decouples hardware customizations from algorithm specifications and treats each hardware customization as a primitive that performs a rewrite on the program. Allo not only decouples the loop-based transformations, but also extends the decoupling to memory, communication, and data types.
  • Reusable parameterized kernel templates: Allo supports declaring type variables during kernel creation and instantiating the kernel when building the hardware executable, which is an important feature for building reusable hardware kernel libraries. Allo introduces a concise grammar for creating kernel templates, eliminating the need for users to possess complicated metaprogramming expertise.
  • Composable schedules: Allo empowers users to construct kernels incrementally from the bottom up, adding customizations one at a time while validating the correctness of each submodule. Ultimately, multiple schedules are progressively integrated into a complete design using the .compose() primitive (see the sketch after this list). This approach, unachievable by prior top-down methods, significantly enhances productivity and debuggability.
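
As a small illustration of this bottom-up flow, here is a minimal sketch (the scale kernel and the pipeline primitive are illustrative choices, not taken from the official tutorials):

import allo
from allo.ir.types import int32

def scale(A: int32[32]) -> int32[32]:
    B: int32[32]
    for i in allo.grid(32):
        B[i] = A[i] * 2
    return B

def top(A: int32[32]) -> int32[32]:
    return scale(A)

# Customize and validate the submodule first ...
s_scale = allo.customize(scale)
s_scale.pipeline("i")

# ... then integrate its schedule into the top-level design
s_top = allo.customize(top)
s_top.compose(s_scale)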

Installation

Please clone the Allo repository to your local machine.

git clone https://github.com/cornell-zhang/allo.git
cd allo

We recommend creating a new conda environment for Allo. Since we are using the latest Python features, the minimum Python version is 3.12.

conda create -n allo python=3.12
conda activate allo

Prerequisites

First, install the LLVM project and the hcl-mlir dialect. You can either use our provided Docker image or build them from source.

Docker

To simplify the installation process, we provide a Docker image with the LLVM 18.x project already installed and patched. Please pull the image from Docker Hub; it also ships a prebuilt hcl dialect, with caveats noted in the comments below.

# * LLVM is installed in /root/llvm-project in the docker image and has already been patched
# * A prebuilt hcl-dialect is installed in /root/hcl-dialect, but note that it may not be up to date
#   You can pull the latest hcl-dialect using `git pull` and rebuild it if needed
docker pull chhzh123/hcl-dialect:llvm-18.x-py3.12
docker run --rm -it chhzh123/hcl-dialect:llvm-18.x-py3.12

Build from source

Users can also choose to build LLVM and the hcl dialect from source. Please follow the instructions below.

# Make sure you are under the correct Python environment
bash build.sh

Install Allo

After installing LLVM and the hcl dialect, we can directly pip install Allo:

# Under the root directory of Allo
python3 -m pip install -e .
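
You can verify the installation with a quick import check (a minimal sanity test; it should run without errors):

# Sanity check for the editable install
import allo
from allo.ir.types import int32
print(allo.__file__)  # should point into your local Allo checkout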

Getting Started

Below is a minimal example of leveraging Allo to customize a GEMM kernel:

import allo
from allo.ir.types import int32

# Allo kernel definition
def gemm(A: int32[32, 32], B: int32[32, 32]) -> int32[32, 32]:
    C: int32[32, 32] = 0
    for i, j, k in allo.grid(32, 32, 32):
        C[i, j] += A[i, k] * B[k, j]
    return C

# Schedule construction
s = allo.customize(gemm)

# Real-time transformation
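# split loop "i" into an outer loop and an inner loop (factor 8)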
s.split("i", factor=8)
print(s.module)

# Compilation
mod = s.build(target="llvm")

# Execution
import numpy as np
np_A = np.random.randint(0, 100, (32, 32)).astype(np.int32)
np_B = np.random.randint(0, 100, (32, 32)).astype(np.int32)
np_C = mod(np_A, np_B)

# Testing
golden_C = np.matmul(np_A, np_B)
np.testing.assert_allclose(np_C, golden_C, rtol=1e-5, atol=1e-5)
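
Beyond the CPU backend, the same schedule can also target HLS flows; for example (assuming a local Vitis HLS installation; the project name here is arbitrary):

mod_hls = s.build(target="vitis_hls", mode="csim", project="gemm.prj")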

Publications

Please refer to our PLDI'24 paper for more details. If you use Allo in your research, please cite us with the following BibTeX entry:

@article{chen2024allo,
    author = {Hongzheng Chen and Niansong Zhang and Shaojie Xiang and Zhichen Zeng and Mengjia Dai and Zhiru Zhang},
    title = {Allo: A Programming Model for Composable Accelerator Design},
    journal = {Proc. ACM Program. Lang.},
    year = {2024},
    month = {jun},
    url = {https://doi.org/10.1145/3656401},
    doi = {10.1145/3656401},
    articleno = {171},
    volume = {8},
    number = {PLDI},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    issue_date = {June 2024},
}


allo's Issues

[BUG] Generate incorrect MLIR #map relationship when reusing some sub-functions with .partition() and .compose() operations

Describe the bug
When the user reuses sub-functions with a .partition() customization in the top function and then uses .compose() to combine all the customizations, the compiler sometimes generates incorrect MLIR #map relationships, leading to compilation failure.

To Reproduce
For example, this code can run successfully and generate correct MLIR output:

def test_reuse_function_1():
    M, N = 2, 2

    def matrix_addi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] + 1
        return B
    s_addi = allo.customize(matrix_addi)
    s_addi.partition(s_addi.A)

    def matrix_subi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] - 1
        return B

    def top(inp: int32[M, N]) -> int32[M, N]:
        temp1 = matrix_addi(inp)
        temp2 = matrix_subi(temp1)
        outp = matrix_addi(temp2)
        return outp

    s = allo.customize(top)
    s.compose(s_addi)
    print(s.module)

Correct output:

#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
module {
  func.func @matrix_addi(%arg0: memref<2x2xi32, #map>) -> memref<2x2xi32> {
    %alloc = memref.alloc() {name = "B"} : memref<2x2xi32>
    affine.for %arg1 = 0 to 2 {
      affine.for %arg2 = 0 to 2 {
        %0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<2x2xi32, #map>
        %1 = arith.extsi %0 : i32 to i33
        %c1_i32 = arith.constant 1 : i32
        %2 = arith.extsi %c1_i32 : i32 to i33
        %3 = arith.addi %1, %2 : i33
        %4 = arith.trunci %3 : i33 to i32
        affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<2x2xi32>
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<2x2xi32>
  }
  func.func @matrix_subi(%arg0: memref<2x2xi32>) -> memref<2x2xi32, #map> {
    %alloc = memref.alloc() {name = "B"} : memref<2x2xi32, #map>
    affine.for %arg1 = 0 to 2 {
      affine.for %arg2 = 0 to 2 {
        %0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<2x2xi32>
        %1 = arith.extsi %0 : i32 to i33
        %c1_i32 = arith.constant 1 : i32
        %2 = arith.extsi %c1_i32 : i32 to i33
        %3 = arith.subi %1, %2 : i33
        %4 = arith.trunci %3 : i33 to i32
        affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<2x2xi32, #map>
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<2x2xi32, #map>
  }
  func.func @top(%arg0: memref<2x2xi32, #map>) -> memref<2x2xi32> {
    %0 = call @matrix_addi(%arg0) {name = "temp1"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    %1 = call @matrix_subi(%0) {name = "temp2"} : (memref<2x2xi32>) -> memref<2x2xi32, #map>
    %2 = call @matrix_addi(%1) {name = "outp"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    return %2 : memref<2x2xi32>
  }
}

But if the user wants to apply some customization to the function matrix_subi():

    def matrix_subi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] - 1
        return B
    s_subi = allo.customize(matrix_subi)

and then uses s.compose(s_subi) to combine the customization:

    s = allo.customize(top)
    s.compose(s_addi)
    s.compose(s_subi)
    print(s.module)

This generates incorrect #map results and reports an error:

Error: failed to run MLIR lower pipeline, printing module...
#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (2)>
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "matrix_addi"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32, #map>, index, index) -> i32
        %2 = "arith.extsi"(%1) : (i32) -> i33
        %3 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %4 = "arith.extsi"(%3) : (i32) -> i33
        %5 = "arith.addi"(%2, %4) : (i33, i33) -> i33
        %6 = "arith.trunci"(%5) : (i33) -> i32
        "affine.store"(%6, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : index, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : index, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32>, sym_name = "matrix_subi"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32>, index, index) -> i32
        %2 = "arith.extsi"(%1) : (i32) -> i33
        %3 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %4 = "arith.extsi"(%3) : (i32) -> i33
        %5 = "arith.subi"(%2, %4) : (i33, i33) -> i33
        %6 = "arith.trunci"(%5) : (i33) -> i32
        "affine.store"(%6, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : index, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : index, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "top"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi}> {name = "temp1"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    %1 = "func.call"(%0) <{callee = @matrix_subi}> {name = "temp2"} : (memref<2x2xi32>) -> memref<2x2xi32, #map>
    %2 = "func.call"(%1) <{callee = @matrix_addi}> {name = "outp"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    "func.return"(%2) : (memref<2x2xi32>) -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "test_testing.py", line 394, in <module>
    test_reuse_function_1()
  File "test_testing.py", line 340, in test_reuse_function_1
    s.compose(s_subi)
  File "/home/jz2292/allo/allo/customize.py", line 77, in wrapper
    _mlir_lower_pipeline(sch.module)
  File "/home/jz2292/allo/allo/build_module.py", line 30, in _mlir_lower_pipeline
    raise e
  File "/home/jz2292/allo/allo/build_module.py", line 25, in _mlir_lower_pipeline
    mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: unknown: 'func.call' op result type mismatch at index 0
 note: unknown: see current operation: %1 = "func.call"(%0) <{callee = @matrix_subi}> {name = "temp2"} : (memref<2x2xi32>) -> memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>
 note: unknown:       op result types: 'memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>'
 note: unknown: function result types: 'memref<2x2xi32>'

This problem can be fixed by adding the customization s_subi.partition(s_subi.B), which generates the correct output again:

    def matrix_subi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] - 1
        return B
    s_subi = allo.customize(matrix_subi)
    s_subi.partition(s_subi.B)

However, in more complex situations this workaround fails. For example:

def test_reuse_function_2():
    M, N = 2, 2

    def matrix_addi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] + 1
        return B
    s_addi = allo.customize(matrix_addi)
    s_addi.partition(s_addi.A)

    def matrix_subi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] - 1
        return B
    s_subi = allo.customize(matrix_subi)
    s_subi.partition(s_subi.B)

    def top(inp: int32[M, N]) -> int32[M, N]:
        temp1 = matrix_addi(inp)
        temp2 = matrix_subi(temp1)
        temp3 = matrix_addi(temp2)
        outp = matrix_subi(temp3)
        return outp

    s = allo.customize(top)
    s.partition(s.outp)
    s.compose(s_addi)
    s.compose(s_subi)
    print(s.module)

This code generates a wrong #map relationship and reports an error again:

Error: failed to run MLIR lower pipeline, printing module...
#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (2)>
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "matrix_addi"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32, #map>, index, index) -> i32
        %2 = "arith.extsi"(%1) : (i32) -> i33
        %3 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %4 = "arith.extsi"(%3) : (i32) -> i33
        %5 = "arith.addi"(%2, %4) : (i33, i33) -> i33
        %6 = "arith.trunci"(%5) : (i33) -> i32
        "affine.store"(%6, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32, #map>, sym_name = "matrix_subi"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32, #map>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32>, index, index) -> i32
        %2 = "arith.extsi"(%1) : (i32) -> i33
        %3 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %4 = "arith.extsi"(%3) : (i32) -> i33
        %5 = "arith.subi"(%2, %4) : (i33, i33) -> i33
        %6 = "arith.trunci"(%5) : (i33) -> i32
        "affine.store"(%6, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32, #map>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32, #map>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32, #map>, sym_name = "top"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi}> {name = "temp1"} : (memref<2x2xi32>) -> memref<2x2xi32>
    %1 = "func.call"(%0) <{callee = @matrix_subi}> {name = "temp2"} : (memref<2x2xi32>) -> memref<2x2xi32, #map>
    %2 = "func.call"(%1) <{callee = @matrix_addi}> {name = "temp3"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    %3 = "func.call"(%2) <{callee = @matrix_subi}> {name = "outp"} : (memref<2x2xi32>) -> memref<2x2xi32, #map>
    "func.return"(%3) : (memref<2x2xi32, #map>) -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "test_testing.py", line 395, in <module>
    test_reuse_function_2()
  File "test_testing.py", line 375, in test_reuse_function_2
    s.partition(s.outp)
  File "/home/jz2292/allo/allo/customize.py", line 77, in wrapper
    _mlir_lower_pipeline(sch.module)
  File "/home/jz2292/allo/allo/build_module.py", line 30, in _mlir_lower_pipeline
    raise e
  File "/home/jz2292/allo/allo/build_module.py", line 25, in _mlir_lower_pipeline
    mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: unknown: 'func.call' op operand type mismatch: expected operand type 'memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>', but provided 'memref<2x2xi32>' for operand number 0
 note: unknown: see current operation: %0 = "func.call"(%arg0) <{callee = @matrix_addi}> {name = "temp1"} : (memref<2x2xi32>) -> memref<2x2xi32>

Expected behavior
The user should not need to add s_subi.partition(s_subi.B) to obtain the correct result; the compiler should detect this automatically.

Additional context
I suspect this happens because s.compose(s_subi) affects s.compose(s_addi) and overwrites some of the MLIR code for the matrix_subi() function.

[BUG] Generate incorrect MLIR #map relationship when .partition() is used in a nested function call

Describe the bug
When the user uses .partition() to partition variables in a nested function call, the compiler generates an incorrect MLIR #map relationship, leading to compilation failure.

To Reproduce
Example:

def test_nested_compose_partition():
    M, N = 2, 2
    def matrix_addi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] + 1
        return B
    s_addi = allo.customize(matrix_addi)
    s_addi.partition(s_addi.A)

    def matrix_addi_top(A: int32[M, N]) -> int32[M, N]:
        B = matrix_addi(A)
        return B
    s_addi_top = allo.customize(matrix_addi_top)
    s_addi_top.compose(s_addi)

    def top(inp: int32[M, N]) -> int32[M, N]:
        outp = matrix_addi_top(inp)
        return outp
    
    s = allo.customize(top)
    # s.partition(s.inp)
    s.compose(s_addi_top)
    print(s.module)

test_nested_compose_partition()

Buggy output
Without the line s.partition(s.inp), it reports an error:

Error: failed to run MLIR lower pipeline, printing module...
#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (2)>
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "matrix_addi"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32, #map>, index, index) -> i32
        %2 = "arith.extsi"(%1) : (i32) -> i33
        %3 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %4 = "arith.extsi"(%3) : (i32) -> i33
        %5 = "arith.addi"(%2, %4) : (i33, i33) -> i33
        %6 = "arith.trunci"(%5) : (i33) -> i32
        "affine.store"(%6, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : index, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : index, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "matrix_addi_top"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi}> {name = "B"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32>, sym_name = "top"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi_top}> {name = "outp"} : (memref<2x2xi32>) -> memref<2x2xi32>
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "test_testing.py", line 393, in <module>
    test_nest_function_partition()
  File "test_testing.py", line 306, in test_nest_function_partition
    s.compose(s_addi_top)
  File "/home/jz2292/allo/allo/customize.py", line 77, in wrapper
    _mlir_lower_pipeline(sch.module)
  File "/home/jz2292/allo/allo/build_module.py", line 30, in _mlir_lower_pipeline
    raise e
  File "/home/jz2292/allo/allo/build_module.py", line 25, in _mlir_lower_pipeline
    mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: unknown: 'func.call' op operand type mismatch: expected operand type 'memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>', but provided 'memref<2x2xi32>' for operand number 0
 note: unknown: see current operation: %0 = "func.call"(%arg0) <{callee = @matrix_addi_top}> {name = "outp"} : (memref<2x2xi32>) -> memref<2x2xi32>

And when the user tries to fix this error by adding s.partition(s.inp), it still reports an error:

Error: failed to run MLIR lower pipeline, printing module...
#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (2)>
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "matrix_addi"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32, #map>, index, index) -> i32
        %2 = "arith.extsi"(%1) : (i32) -> i33
        %3 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %4 = "arith.extsi"(%3) : (i32) -> i33
        %5 = "arith.addi"(%2, %4) : (i33, i33) -> i33
        %6 = "arith.trunci"(%5) : (i33) -> i32
        "affine.store"(%6, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : index, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : index, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32, #map>) -> memref<2x2xi32>, sym_name = "matrix_addi_top"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi}> {name = "B"} : (memref<2x2xi32, #map>) -> memref<2x2xi32>
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32>, sym_name = "top"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi_top}> {name = "outp"} : (memref<2x2xi32>) -> memref<2x2xi32>
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "test_testing.py", line 393, in <module>
    test_nest_function_partition()
  File "test_testing.py", line 306, in test_nest_function_partition
    s.compose(s_addi_top)
  File "/home/jz2292/allo/allo/customize.py", line 77, in wrapper
    _mlir_lower_pipeline(sch.module)
  File "/home/jz2292/allo/allo/build_module.py", line 30, in _mlir_lower_pipeline
    raise e
  File "/home/jz2292/allo/allo/build_module.py", line 25, in _mlir_lower_pipeline
    mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: unknown: 'func.call' op operand type mismatch: expected operand type 'memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>', but provided 'memref<2x2xi32>' for operand number 0
 note: unknown: see current operation: %0 = "func.call"(%arg0) <{callee = @matrix_addi_top}> {name = "outp"} : (memref<2x2xi32>) -> memref<2x2xi32>

Expected behavior
This program should generate correct output whether or not the user adds s.partition(s.inp).

[Feature] Allow list of AlloType as dtype for Structure

Is your feature request related to a problem? Please describe.
Yes. Currently, Allo's Structure type asserts that each dtype input is an AlloType, so if the input type is an array such as uint8[16], there is an assertion failure.

Something like this works

t = Structure({
    "uint": uint8
})

but the following doesn't

t1 = Structure({
    "uint list": uint8[16]
})

Describe the solution you'd like
A list of AlloType should also be allowed as a dtype when creating a type using Structure.

[BUG] `opt-level` setting segfaults

Describe the bug

When we call memref.shape(), it causes a segfault even though the generated IR is correct. Niansong has documented this issue before.

We have concluded that with opt-level = 0, 1, or 2, IR containing memref.shape executes correctly, but with opt-level = 0 or 1 there are numerical errors in various kinds of calculations.

Solution

Given the above, we now set opt-level = 2 as the default:

execution_engine = ExecutionEngine(
    lowered, opt_level=2, shared_libs=shared_libs
)

Examples

memref.shape()

module {
  memref.global "private" constant @const_0 : memref<3xi64> = dense<[5, 2, 4]>
  memref.global "private" constant @const_1 : memref<2xi64> = dense<[5, 8]>
  func.func @kernel(%arg0: memref<5x3x2xf32>, %arg1: memref<4x3xf32>, %arg2: memref<4xf32>) -> memref<5x8xf32> attributes {itypes = "___", otypes = "_"} {
    %c1_i32 = arith.constant 1 : i32
    %0 = arith.sitofp %c1_i32 : i32 to f32
    %1 = arith.negf %0 : f32
    %c2_i32 = arith.constant 2 : i32
    %2 = arith.sitofp %c2_i32 : i32 to f32
    %3 = arith.negf %2 : f32
    %alloc = memref.alloc() {name = "output1"} : memref<5x2x3xf32>
    %c0_i32 = arith.constant 0 : i32
    %4 = arith.sitofp %c0_i32 : i32 to f32
    linalg.fill {op_name = "transpose_init_zero_0"} ins(%4 : f32) outs(%alloc : memref<5x2x3xf32>)
    linalg.transpose ins(%arg0 : memref<5x3x2xf32>) outs(%alloc : memref<5x2x3xf32>) permutation = [0, 2, 1]  {op_name = "transpose_1"}
    %alloc_0 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_1 = arith.constant 0 : i32
    %5 = arith.sitofp %c0_i32_1 : i32 to f32
    linalg.fill {op_name = "linear_init_zero_2"} ins(%5 : f32) outs(%alloc_0 : memref<5x2x4xf32>)
    %alloc_2 = memref.alloc() : memref<3x4xf32>
    %c0_i32_3 = arith.constant 0 : i32
    %6 = arith.sitofp %c0_i32_3 : i32 to f32
    linalg.fill {op_name = "transpose_init_zero_3"} ins(%6 : f32) outs(%alloc_2 : memref<3x4xf32>)
    linalg.transpose ins(%arg1 : memref<4x3xf32>) outs(%alloc_2 : memref<3x4xf32>) permutation = [1, 0]  {op_name = "transpose_4"}
    %alloc_4 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_5 = arith.constant 0 : i32
    %7 = arith.sitofp %c0_i32_5 : i32 to f32
    linalg.fill {op_name = "matmul_init_zero_5"} ins(%7 : f32) outs(%alloc_4 : memref<5x2x4xf32>)
    %alloc_6 = memref.alloc() : memref<5x3x4xf32>
    linalg.broadcast ins(%alloc_2 : memref<3x4xf32>) outs(%alloc_6 : memref<5x3x4xf32>) dimensions = [0] 
    %alloc_7 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_8 = arith.constant 0 : i32
    %8 = arith.sitofp %c0_i32_8 : i32 to f32
    linalg.fill {op_name = "bmm_init_zero_6"} ins(%8 : f32) outs(%alloc_7 : memref<5x2x4xf32>)
    linalg.batch_matmul {op_name = "bmm_7"} ins(%alloc, %alloc_6 : memref<5x2x3xf32>, memref<5x3x4xf32>) outs(%alloc_7 : memref<5x2x4xf32>)
    %alloc_9 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_10 = arith.constant 0 : i32
    %9 = arith.sitofp %c0_i32_10 : i32 to f32
    linalg.fill {op_name = "view_init_zero_8"} ins(%9 : f32) outs(%alloc_9 : memref<5x2x4xf32>)
    %10 = memref.get_global @const_0 : memref<3xi64>
    %reshape = memref.reshape %alloc_7(%10) : (memref<5x2x4xf32>, memref<3xi64>) -> memref<5x2x4xf32>
    %alloc_11 = memref.alloc() : memref<5x2x4xf32>
    linalg.broadcast ins(%arg2 : memref<4xf32>) outs(%alloc_11 : memref<5x2x4xf32>) dimensions = [0, 1] 
    %alloc_12 = memref.alloc() {name = "output2"} : memref<5x2x4xf32>
    %c0_i32_13 = arith.constant 0 : i32
    %11 = arith.sitofp %c0_i32_13 : i32 to f32
    linalg.fill {op_name = "add_init_zero_9"} ins(%11 : f32) outs(%alloc_12 : memref<5x2x4xf32>)
    linalg.add {op_name = "add_10"} ins(%reshape, %alloc_11 : memref<5x2x4xf32>, memref<5x2x4xf32>) outs(%alloc_12 : memref<5x2x4xf32>)
    %alloc_14 = memref.alloc() : memref<5x8xf32>
    %c0_i32_15 = arith.constant 0 : i32
    %12 = arith.sitofp %c0_i32_15 : i32 to f32
    linalg.fill {op_name = "view_init_zero_11"} ins(%12 : f32) outs(%alloc_14 : memref<5x8xf32>)
    %13 = memref.get_global @const_1 : memref<2xi64>
    %reshape_16 = memref.reshape %alloc_12(%13) {name = "output"} : (memref<5x2x4xf32>, memref<2xi64>) -> memref<5x8xf32>
    return %reshape_16 : memref<5x8xf32>
  }


  func.func @main() {
	%arg0 = memref.alloc() : memref<5x3x2xf32>
	%arg1 = memref.alloc() : memref<4x3xf32>
	%arg2 = memref.alloc() : memref<4xf32>
	%arg3 = func.call @kernel(%arg0, %arg1, %arg2) : (memref<5x3x2xf32>, memref<4x3xf32>, memref<4xf32>) -> memref<5x8xf32>
	return
  }

}

stack trace

 #3 0x00007f2186a5d950 llvm::isPotentiallyReachable(llvm::Instruction const*, llvm::Instruction const*, llvm::SmallPtrSetImpl<llvm::BasicBlock*> const*, llvm::DominatorTree const*, llvm::LoopInfo const*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a21950)
 #4 0x00007f2186a2d619 llvm::EarliestEscapeInfo::isNotCapturedBeforeOrAt(llvm::Value const*, llvm::Instruction const*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59f1619)
 #5 0x00007f2186a276bb llvm::BasicAAResult::getModRefInfo(llvm::CallBase const*, llvm::MemoryLocation const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59eb6bb)
 #6 0x00007f2186a0756b llvm::AAResults::getModRefInfo(llvm::CallBase const*, llvm::MemoryLocation const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59cb56b)
 #7 0x00007f2186a08e51 llvm::AAResults::getModRefInfo(llvm::Instruction const*, std::optional<llvm::MemoryLocation> const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59cce51)
 #8 0x00007f218638ebdf (anonymous namespace)::DSEState::isReadClobber(llvm::MemoryLocation const&, llvm::Instruction*) DeadStoreElimination.cpp:0:0
 #9 0x00007f2186398eae (anonymous namespace)::DSEState::getDomMemoryDef(llvm::MemoryDef*, llvm::MemoryAccess*, llvm::MemoryLocation const&, llvm::Value const*, unsigned int&, unsigned int&, bool, unsigned int&) DeadStoreElimination.cpp:0:0
#10 0x00007f218639af52 (anonymous namespace)::eliminateDeadStores(llvm::Function&, llvm::AAResults&, llvm::MemorySSA&, llvm::DominatorTree&, llvm::PostDominatorTree&, llvm::AssumptionCache&, llvm::TargetLibraryInfo const&, llvm::LoopInfo const&) DeadStoreElimination.cpp:0:0
#11 0x00007f218639d038 llvm::DSEPass::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5361038)
#12 0x00007f218554e35e llvm::detail::PassModel<llvm::Function, llvm::DSEPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x451235e)
#13 0x00007f218737d514 llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x6341514)
#14 0x00007f21855472ce llvm::detail::PassModel<llvm::Function, llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x450b2ce)
#15 0x00007f2186a7739f llvm::CGSCCToFunctionPassAdaptor::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a3b39f)
#16 0x00007f218554d40e llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::CGSCCToFunctionPassAdaptor, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x451140e)
#17 0x00007f2186a6ff5b llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a33f5b)
#18 0x00007f218554d3ce llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x45113ce)
#19 0x00007f2186a73a55 llvm::DevirtSCCRepeatedPass::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a37a55)
#20 0x00007f218554d3ee llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::DevirtSCCRepeatedPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x45113ee)
#21 0x00007f2186a71bf9 llvm::ModuleToPostOrderCGSCCPassAdaptor::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a35bf9)
#22 0x00007f21857956ef llvm::ModuleInlinerWrapperPass::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x47596ef)
#23 0x00007f218554cf8e llvm::detail::PassModel<llvm::Module, llvm::ModuleInlinerWrapperPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x4510f8e)
#24 0x00007f2185542f50 mlir::makeOptimizingTransformer(unsigned int, unsigned int, llvm::TargetMachine*)::'lambda'(llvm::Module*)::operator()(llvm::Module*) const OptUtils.cpp:0:0
#25 0x00007f2185543cad std::_Function_handler<llvm::Error (llvm::Module*), mlir::makeOptimizingTransformer(unsigned int, unsigned int, llvm::TargetMachine*)::'lambda'(llvm::Module*)>::_M_invoke(std::_Any_data const&, llvm::Module*&&) OptUtils.cpp:0:0
#26 0x00007f218264550d llvm::Error llvm::function_ref<llvm::Error (llvm::Module*)>::callback_fn<std::function<llvm::Error (llvm::Module*)>>(long, llvm::Module*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x160950d)
#27 0x00007f218306e1e6 mlir::ExecutionEngine::create(mlir::Operation*, mlir::ExecutionEngineOptions const&, std::unique_ptr<llvm::TargetMachine, std::default_delete<llvm::TargetMachine>>) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x20321e6)
#28 0x00007f2182646ab5 mlirExecutionEngineCreate (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x160aab5)
#29 0x00007f21803de5e5 pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)::operator()(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool) const /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp:82:77
#30 0x00007f21803df91a void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)::operator()(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool) const /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/detail/init.h:242:29
#31 0x00007f21803e37e3 pybind11::class_<(anonymous namespace)::PyExecutionEngine> pybind11::detail::argument_loader<pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool>::call_impl<void, void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)&, 0ul, 1ul, 2ul, 3ul, 4ul, pybind11::detail::void_type>(void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)&, std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul>, pybind11::detail::void_type&&) && /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1205:91
#32 0x00007f21803e3439 _ZNO8pybind116detail15argument_loaderIJRNS0_16value_and_holderE10MlirModuleiRKSt6vectorISsSaISsEEbEE4callIvNS0_9void_typeERZNOS0_8initimpl7factoryIZL34pybind11_init__mlirExecutionEngineRNS_7module_EEUlS4_iS9_bE_PFSC_vEFPN12_GLOBAL__N_117PyExecutionEngineES4_iS9_bESI_E7executeINS_6class_ISL_JEEEJNS_3argENS_5arg_vEST_ST_A327_cEEEvRT_DpRKT0_EUlS3_S4_iS9_bE_EENSt9enable_ifIXsrSt7is_voidISV_E5valueESC_E4typeEOT1_ /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1183:26

numerical error

when opt-level=0:

    def test_compare_int_float():
        Ty = Int(5)
    
        def kernel(A: Ty) -> Ty:
            B: Ty = 0
            if A > B or A + 1 < 0.0:
                B = A
            return B
    
        s = allo.customize(kernel)
        mod = s.build()
        assert mod(2) == kernel(2)
>       assert mod(-3) == kernel(-3)
E       assert 29 == -3
E        +  where 29 = <allo.backend.llvm.LLVMModule object at 0x7f8d21573c70>(-3)
E        +  and   -3 = <function test_compare_int_float.<locals>.kernel at 0x7f8d21d588b0>(-3)

[BUG] Loop-carried dependences should be SSA values, not memory operations

Describe the bug
Neither writing kernels with primitives like matmul() nor using allo.grid() makes use of affine.for's ability to carry iteration arguments. For us, this is important for pipelining. Here is an example of the MLIR produced by test_reduce() (shown further down).

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    affine.for %arg1 = 0 to 20 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      // We have to reload the value
      // ... when it should be forwarded from last iteration's store
      %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
      %3 = arith.addi %2, %1 : i32
      affine.store %3, %alloc[0] {to = "sum"} : memref<1xi32>
    } {loop_name = "i", op_name = "S_i_0", reduction}
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %0 : i32
  }
}

To Reproduce
The linalg dialect compounds the issue, because linalg ops are lowered to affine loops without an accumulator:

def test_linalg_matmul():
    N = 16
    from allo import matmul

    def kernel(A: int32[N, N], B: int32[N, N]) -> int32[N, N]:
        return matmul(A, B)

    s = allo.customize(kernel)
    print(s.module)

But even with an explicit accumulator in a single memref cell, I can't get it to be raised to SSA values:

def test_reduce():
    N = 20

    def kernel(A: int32[N]) -> int32:
        sum: int32 = 0
        for i in allo.reduction(N):
            sum += A[i]
        return sum

    s = allo.customize(kernel)
    print(s.module)

Buggy output
I was not hopeful that the existing MLIR passes would help with this issue, but I tried anyway by running mlir-opt --convert-linalg-to-affine-loops --affine-scalrep --lower-affine --convert-scf-to-cf --mem2reg.

The mem2reg pass is only expected to work on unstructured control flow, but I could not get it to work even there.

Expected behavior
Here is an example of how we do matmul in affine that uses iteration arguments to assist the pipelining pass:

affine.for %arg3 = 0 to 16 {
  affine.for %arg4 = 0 to 16 {
    %sum = affine.for %arg5 = 0 to 16
        iter_args(%sum_iter = %c0_i32) -> (i32) {
      %2 = affine.load %A[%arg3, %arg5] : memref<16x16xi32>
      %3 = affine.load %B[%arg5, %arg4] : memref<16x16xi32>
      %4 = arith.muli %2, %3 : i32
      %sum_next = arith.addi %4, %sum_iter : i32
      affine.yield %sum_next : i32
    }
    affine.store %sum, %C[%arg3, %arg4] : memref<16x16xi32>
  }
}

Perhaps there are the right patterns/passes in MLIR to accomplish what we want, but I haven't found them yet. Maybe we will have to write our own pass for this or lower the AST differently.

[BUG] Vitis HLS backend results not written back to argument

Describe the bug
The Vitis HLS backend generates buggy code: the results are not written back to the arguments.

To Reproduce

import allo
from allo.ir.types import int32

N = 256

def compute(
    x: int32[N],
    y: int32[N]
):
    for i in range(N):
        y[i] = x[i]

s = allo.customize(compute)
s.build(target="vitis_hls", mode="csim", project="test.prj")

Buggy output

//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for High-level Synthesis (HLS).
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <ap_axi_sdata.h>
#include <ap_fixed.h>
#include <ap_int.h>
#include <hls_math.h>
#include <hls_stream.h>
#include <math.h>
#include <stdint.h>
using namespace std;

extern "C" {

void compute(
  int32_t *v0,
  int32_t *v1
) {	// L2
  #pragma HLS interface m_axi port=v0 offset=slave bundle=gmem0
  #pragma HLS interface m_axi port=v1 offset=slave bundle=gmem1
  int32_t buf0[256];	//
  l_S_buf0_buf0_l_0: for (int buf0_l_0 = 0; buf0_l_0 < 256; buf0_l_0++) {	//
  #pragma HLS pipeline II=1 rewind
    int32_t v4 = v0[buf0_l_0];	//
    buf0[buf0_l_0] = v4;	//
  }
  int32_t buf1[256];	//
  l_S_buf1_buf1_l_0: for (int buf1_l_0 = 0; buf1_l_0 < 256; buf1_l_0++) {	//
  #pragma HLS pipeline II=1 rewind
    int32_t v7 = v1[buf1_l_0];	//
    buf1[buf1_l_0] = v7;	//
  }
  l_S_i_0_i: for (int i = 0; i < 256; i++) {	// L3
    int32_t v9 = buf0[i];	// L4
    buf1[i] = v9;	// L5
  }
}


} // extern "C"

buf1 is the result buffer, but it is never written back to v1; therefore, the entire kernel is considered dead code.

Expected behavior
Result tensors passed in as arguments should be written back to the corresponding interface pointers.

Additional context
This is related to Vitis's interface requirements; it needs further consideration so as not to affect the existing systolic array examples.

[Feature] Use wraparound math for intermediate vals when desired

Is your feature request related to a problem? Please describe.
Yes, arbitrary-precision types for intermediate values are causing hitches in the AMC flow. For example, in #121 the typecasting caused extra buffers and copy loops in the IR, which increased design latency. In this latest case, I am trying to evaluate a polynomial, and the majority of the IR becomes dedicated to extending numbers in order to create arith ops that are large enough to hold overflow bits. It would be nice if we could toggle this behavior to simply use wraparound math when the lval and rval types match.

Here is an example of the current behavior:

def test_tay_approximation(printResEst=False):
    N = 16

    def kernel_approx(x: int32[N]) -> int32[N]:
        y: int32[N] = 0
        for i in allo.grid(N):
            y[i] = (
                x[i] * 10000
                - (x[i] * x[i] * x[i] * 17) * 10
                - (x[i] * x[i] * x[i] * x[i] * x[i] * 11)
            )
        return y

    s = allo.customize(kernel_approx)
    print(s.module)

The output is

module {
  func.func @kernel_approx(%arg0: memref<16xi32>) -> memref<16xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "y"} : memref<16xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<16xi32>)
    affine.for %arg1 = 0 to 16 {
      %0 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %1 = arith.extsi %0 : i32 to i64
      %c10000_i32 = arith.constant 10000 : i32
      %2 = arith.extsi %c10000_i32 : i32 to i64
      %3 = arith.muli %1, %2 : i64
      %4 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %5 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %6 = arith.extsi %4 : i32 to i64
      %7 = arith.extsi %5 : i32 to i64
      %8 = arith.muli %6, %7 : i64
      %9 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %10 = arith.extsi %8 : i64 to i96
      %11 = arith.extsi %9 : i32 to i96
      %12 = arith.muli %10, %11 : i96
      %13 = arith.extsi %12 : i96 to i128
      %c17_i32 = arith.constant 17 : i32
      %14 = arith.extsi %c17_i32 : i32 to i128
      %15 = arith.muli %13, %14 : i128
      %16 = arith.extsi %15 : i128 to i160
      %c10_i32 = arith.constant 10 : i32
      %17 = arith.extsi %c10_i32 : i32 to i160
      %18 = arith.muli %16, %17 : i160
      %19 = arith.extsi %3 : i64 to i161
      %20 = arith.extsi %18 : i160 to i161
      %21 = arith.subi %19, %20 : i161
      %22 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %23 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %24 = arith.extsi %22 : i32 to i64
      %25 = arith.extsi %23 : i32 to i64
      %26 = arith.muli %24, %25 : i64
      %27 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %28 = arith.extsi %26 : i64 to i96
      %29 = arith.extsi %27 : i32 to i96
      %30 = arith.muli %28, %29 : i96
      %31 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %32 = arith.extsi %30 : i96 to i128
      %33 = arith.extsi %31 : i32 to i128
      %34 = arith.muli %32, %33 : i128
      %35 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %36 = arith.extsi %34 : i128 to i160
      %37 = arith.extsi %35 : i32 to i160
      %38 = arith.muli %36, %37 : i160
      %39 = arith.extsi %38 : i160 to i192
      %c11_i32 = arith.constant 11 : i32
      %40 = arith.extsi %c11_i32 : i32 to i192
      %41 = arith.muli %39, %40 : i192
      %42 = arith.extsi %21 : i161 to i193
      %43 = arith.extsi %41 : i192 to i193
      %44 = arith.subi %42, %43 : i193
      %45 = arith.trunci %44 : i193 to i32
      affine.store %45, %alloc[%arg1] {to = "y"} : memref<16xi32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<16xi32>
  }
}

Describe the solution you'd like
In the end, I would expect Vivado to be able to optimize away a lot of these wires. However, toggling this behavior has some practical benefits for us: it reduces the number of nets in waveform dumps for debugging, and Calyx has some trouble supporting ap types wider than 64 bits.

[BUG] `arith.fptoui` issue with negative loop bound

Describe the bug
A loop with a negative bound fails at the IR building stage.

To Reproduce

import allo
from allo.ir.types import int32

N = 256

def compute(
    x: int32[N],
    y: int32[N]
):
    for i in range(N-1, -1, -1):
        y[i+1] = x[i+1]


s = allo.customize(compute)
print(s.module)

Buggy output

Traceback (most recent call last):
  File "/work/shared/users/phd/nz264/machsuite-allo/viterbi/viterbi.py", line 63, in <module>
    mod = s.build()
          ^^^^^^^^^
  File "/work/shared/users/phd/nz264/allo/allo/customize.py", line 682, in build
    return LLVMModule(
           ^^^^^^^^^^^
  File "/work/shared/users/phd/nz264/allo/allo/backend/llvm.py", line 59, in __init__
    self.module = Module.parse(str(mod), ctx)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
hcl_mlir._mlir_libs._site_initialize.<locals>.MLIRError: Unable to parse module assembly:
error: "-":154:11: 'arith.fptoui' op result #0 must be signless-fixed-width-integer-like, but got 'index'
 note: "-":154:11: see current operation: %30 = "arith.fptoui"(%25) : (f32) -> index

[BUG] KeyError: 'attempt to access a non-existent attribute'

An error occurs in the buffer generation pass.

import allo
from allo.ir.types import int32

def func(A: int32[8]) -> int32[8]:
    return A

def top(A: int32[8]) -> int32[8]:
    return func(A)

s = allo.customize(top)
print(s.module)
s.build(target="vitis_hls", mode="csim")

This will lead to the following error:

Traceback (most recent call last):
  File "/scratch/users/hc676/allo/test.py", line 12, in <module>
    s.build(target="vitis_hls", mode="csim")
  File "/scratch/users/hc676/allo/allo/customize.py", line 685, in build
    return HLSModule(
           ^^^^^^^^^^
  File "/scratch/users/hc676/allo/allo/backend/hls.py", line 175, in __init__
    buffers = generate_input_output_buffers(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/users/hc676/allo/allo/passes.py", line 159, in generate_input_output_buffers
    MockBuffer(top_func_name, arg.owner.attributes["name"].value)
                              ~~~~~~~~~~~~~~~~~~~~^^^^^^^^
KeyError: 'attempt to access a non-existent attribute'

[BUG] Unused memory allocation when enable_tensor is disabled

Describe the bug
When enable_tensor=False, redundant memref.allocs are generated. These memory blocks are allocated and initialized, but never used again.

To Reproduce
In allo/examples/bert_layer.py, we set enable_tensor=False and monitor_memref=True. The function get_mem_usage generates the file bert_layer_mem_usage.txt.

Buggy output
The beginning of the contents is as follows:

+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
|    name    |       shape       | dtype | mem(bits)  |     BRAM(18K)      | store counts |                                     data storage                                     |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
|   %alloc   |   [2, 512, 768]   |  f32  |  25165824  |       1398.1       |      0       |                                                                                      |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
|  %alloc_7  |   [2, 768, 768]   |  f32  |  37748736  |      2097.15       |      1       |            %30 = memref.load %alloc_1[%arg2, %arg3] : memref<768x768xf32>            |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
|  %alloc_2  |      [2, 512]     |  f32  |   32768    |        1.82        |      0       |                                                                                      |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
|  %alloc_3  |      [2, 512]     |  f32  |   32768    |        1.82        |      1       |                             %5 = arith.addf %4, %2 : f32                             |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
|  %alloc_1  |     [768, 768]    |  f32  |  18874368  |      1048.58       |      1       |               %30 = memref.load %0[%arg1, %arg2] : memref<768x768xf32>               |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_11  |   [2, 512, 768]   |  f32  |  25165824  |       1398.1       |      1       |                           %34 = arith.addf %32, %33 : f32                            |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_18  |   [2, 512, 768]   |  f32  |  25165824  |       1398.1       |      1       |                    %30 = memref.load %1[%arg3] : memref<768xf32>                     |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_22  |   [2, 512, 768]   |  f32  |  25165824  |       1398.1       |      1       |                           %32 = arith.addf %30, %31 : f32                            |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_29  |  [2, 512, 12, 64] |  f32  |  25165824  |       1398.1       |      0       |                                                                                      |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_33  |  [2, 12, 512, 64] |  f32  |  25165824  |       1398.1       |      1       |   %30 = memref.load %reshape[%arg1, %arg2, %arg3, %arg4] : memref<2x512x12x64xf32>   |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_40  |   [2, 512, 768]   |  f32  |  25165824  |       1398.1       |      0       |                                                                                      |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_44  |     [768, 768]    |  f32  |  18874368  |      1048.58       |      1       |               %30 = memref.load %2[%arg1, %arg2] : memref<768x768xf32>               |
+------------+-------------------+-------+------------+--------------------+--------------+--------------------------------------------------------------------------------------+
| %alloc_51  |   [2, 768, 768]   |  f32  |  37748736  |      2097.15       |      1       |           %30 = memref.load %alloc_44[%arg2, %arg3] : memref<768x768xf32>            |

%alloc, %alloc_2, %alloc_29, and %alloc_40 are allocated but never used. Many other allocations among the total 72 memref.allocs generated from this example have the same problem.

Expected behavior
These allocations should not exist, and every value in the "store counts" column should be at least 1.

[BUG] Excessive Copy Loops

Describe the bug
Excessive copy loops are created due to data type conversion of tensors expressed in the linalg dialect.

To Reproduce

import allo
from allo.ir.types import uint32

N = 20  # matches the memref<20xi32> shapes in the output below

def test_vadd():
    from allo import add

    def kernel(A: uint32[N], B: uint32[N]) -> uint32[N]:
        return A + B

    s = allo.customize(kernel)
    print(s.module)

Buggy output

#map = affine_map<(d0) -> (d0)>
module {
  func.func @kernel(%arg0: memref<20xi32>, %arg1: memref<20xi32>) -> memref<20xi32> attributes {itypes = "uu", otypes = "u"} {
    %alloc = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : memref<20xi32>) outs(%alloc : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_0 = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : memref<20xi32>) outs(%alloc_0 : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_1 = memref.alloc() : memref<20xi33>
    linalg.add {op_name = "add_0"} ins(%alloc, %alloc_0 : memref<20xi33>, memref<20xi33>) outs(%alloc_1 : memref<20xi33>)
    %alloc_2 = memref.alloc() {unsigned} : memref<20xi32>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%alloc_1 : memref<20xi33>) outs(%alloc_2 : memref<20xi32>) {
    ^bb0(%in: i33, %out: i32):
      %0 = arith.trunci %in : i33 to i32
      linalg.yield %0 : i32
    }
    return %alloc_2 : memref<20xi32>
  }
}

In short, when this is lowered to affine, it manifests as excessive copying at the beginning of the program, and our AMC flow is very sensitive to this.

What should really happen is to notice that the type the addition result is bound to is the same as the input type, so the add can simply be (i32, i32) -> i32 with normal wraparound.
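
For reference, a loop-based rewrite of the same kernel (a sketch in the spirit of the for-loop versions used elsewhere in these reports) expresses exactly that: the add stays (i32, i32) -> i32 and no intermediate widened tensors are allocated:

def kernel(A: uint32[N], B: uint32[N]) -> uint32[N]:
    C: uint32[N] = 0
    for i in range(N):
        C[i] = A[i] + B[i]  # scalar add, stored straight back as uint32
    return C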

[BUG] The compiler cannot identify a previously defined loop

Describe the bug
When running the buffer_at transformation, the compiler cannot find the band of loop j, even though loop j is defined in the kernel.

To Reproduce
It happens after customizing the kernel.

def kernel_md[T: (float32, int32), M: int32, N: int32](
    position_x: "T[M]", position_y: "T[M]", position_z: "T[M]",
    NL: "T[M, N]", force_x: "T[M]", fx: "T[M]",
    delx: "T[N]", dely: "T[N]", delz: "T[N]",
    r2inv: "T[N]", r6inv: "T[N]",
):
    for i in range(M):
        for j in range(N):
            jidx: int32 = NL[i, j]
            delx[j] = position_x[i] - position_x[jidx]
            dely[j] = position_y[i] - position_y[jidx]
            delz[j] = position_z[i] - position_z[jidx]
            if (delx[j] * delx[j] + dely[j] * dely[j] + delz[j] * delz[j]) == 0:
                r2inv[j] = (domainEdge * domainEdge * 3.0) * 1000
            else:
                r2inv[j] = 1.0 / (delx[j] * delx[j] + dely[j] * dely[j] + delz[j] * delz[j])
            r6inv[j] = r2inv[j] * r2inv[j] * r2inv[j]
            fx[i] = fx[i] + delx[j] * r2inv[j] * r6inv[j] * (lj1 * r6inv[j] - lj2)
        force_x[i] = fx[i]

sch0 = allo.customize(kernel_md, instantiate=[concrete_type, m, n])
print(sch0.module)
sch0.split("i", factor=8)
sch0.split("j", factor=8)
sch0.buffer_at(sch0.force_x, axis="j")

Buggy output
Traceback (most recent call last):
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/examples/polybench/md_2_knn.py", line 60, in <module>
    mod_test = md(float32, M, N)
               ^^^^^^^^^^^^^^^^
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/examples/polybench/md_2_knn.py", line 47, in md
    sch0.buffer_at(sch0.force_x, axis="j")
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/customize.py", line 110, in wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/customize.py", line 429, in buffer_at
    band_name, axis = find_loop_in_bands(func, axis)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/ir/transform.py", line 104, in find_loop_in_bands
    raise RuntimeError(f"Cannot find the band of loop {axis_name}")
RuntimeError: Cannot find the band of loop j

Expected behavior
Allo should automatically create an intermediate buffer for force_x and attach it inside loop j. (Note that sch0.split("j", factor=8) renames the band to j.outer/j.inner, which may be why axis="j" can no longer be resolved.)

[BUG] Tuple not supported in dataflow analysis

Describe the bug
The use-def analysis fails for functions with multiple return values.

To Reproduce

def callee(a: float32, b: float32) -> (float32, float32):
    c: float32 = a + b
    d: float32 = a - b
    return c, d

def kernel(A: float32[10], B: float32[10]) -> (float32[10], float32[10]):
    C: float32[10] = 0
    D: float32[10] = 0
    for i in range(10):
        C[i], D[i] = callee(A[i], B[i])
    return C, D

s = allo.customize(kernel)
print(s.module)

Buggy output

Traceback (most recent call last):
  File "test_builder.py", line 556, in <module>
    non_tensor_args()
  File "test_builder.py", line 545, in non_tensor_args
    s = allo.customize(kernel)
  File "/work/shared/users/phd/nz264/allo/allo/customize.py", line 826, in customize
    use_def_chain.visit(tree)
  File "/home/nz264/anaconda3/envs/mlir/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/work/shared/users/phd/nz264/allo/allo/ir/use_def.py", line 282, in visit_Module
    res.append(self.visit(stmt))
  File "/home/nz264/anaconda3/envs/mlir/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/work/shared/users/phd/nz264/allo/allo/ir/use_def.py", line 270, in visit_FunctionDef
    res.append(self.visit(stmt))
  File "/home/nz264/anaconda3/envs/mlir/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/work/shared/users/phd/nz264/allo/allo/ir/use_def.py", line 137, in visit_For
    res.append(self.visit(stmt))
  File "/home/nz264/anaconda3/envs/mlir/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/work/shared/users/phd/nz264/allo/allo/ir/use_def.py", line 212, in visit_Assign
    name = get_name(node.targets[0])
  File "/work/shared/users/phd/nz264/allo/allo/ir/use_def.py", line 210, in get_name
    return get_name(subnode.value)
AttributeError: 'Tuple' object has no attribute 'value'

[BUG] Relu allocates extra memory for storing broadcasted zero tensor

Describe the bug
allo.relu allocates an extra memory block for storing the zero tensor to which the input is compared.

To Reproduce
Run mlp.py without enable_tensor and print the intermediate_module. The last part of the output is:

    %alloc_50 = memref.alloc() : memref<30x30xf32>
    %c0_51 = arith.constant 0 : index
    %c30_52 = arith.constant 30 : index
    %c1_53 = arith.constant 1 : index
    scf.for %arg1 = %c0_51 to %c30_52 step %c1_53 {
      %c0_57 = arith.constant 0 : index
      %c30_58 = arith.constant 30 : index
      %c1_59 = arith.constant 1 : index
      scf.for %arg2 = %c0_57 to %c30_58 step %c1_59 {
        memref.store %cst, %alloc_50[%arg1, %arg2] : memref<30x30xf32>
      }
    }
    %c0_54 = arith.constant 0 : index
    %c30_55 = arith.constant 30 : index
    %c1_56 = arith.constant 1 : index
    scf.for %arg1 = %c0_54 to %c30_55 step %c1_56 {
      %c0_57 = arith.constant 0 : index
      %c30_58 = arith.constant 30 : index
      %c1_59 = arith.constant 1 : index
      scf.for %arg2 = %c0_57 to %c30_58 step %c1_59 {
        %4 = memref.load %alloc_39[%arg1, %arg2] : memref<30x30xf32>
        %5 = memref.load %alloc_50[%arg1, %arg2] : memref<30x30xf32>
        %6 = arith.maxf %4, %5 : f32
        memref.store %6, %alloc_46[%arg1, %arg2] : memref<30x30xf32>
      }
    }

%cst, a 0 constant, is broadcast into %alloc_50. This allocation is not needed.

Expected behavior
To work around this, I added an allo.max feature to dsl.py:

def max(x, y):
    return np.maximum(x, y)

and added a "max" entry to builder.py:

if isinstance(arg_type, (F32Type, IntegerType)):
    opcls = {
        "exp": math_d.ExpOp,
        "log": math_d.LogOp,
        "log2": math_d.Log2Op,
        "log10": math_d.Log10Op,
        "sqrt": math_d.SqrtOp,
        "sin": math_d.SinOp,
        "cos": math_d.CosOp,
        "tan": math_d.TanOp,
        "tanh": math_d.TanhOp,
        "power": math_d.PowFOp,
        "max": arith_d.MaxFOp,
    }.get(fn_name)

When implementing ReLU using for loops, there is no overhead:

for i, j in allo.grid(30, 30):
    C[i, j] = allo.max(C[i, j], 0.0)

[BUG] FIFO read in loop attempts to access repeated data

This is an issue I hit while turning PolyBench 2mm into a dataflow design. It can be fixed later.

Describe the bug
There should be a loop-invariant code motion pass before turning load/store operations into FIFO read/write operations. When a read happens in a loop whose induction variable does not affect the load indices, turning it directly into a FIFO read can result in accessing the wrong data.

To Reproduce

def mm1(A: T[P, Q], B: T[Q, R], out_AB: T[P, R]):
    for i0, j0 in allo.grid(P, R):
        for k0 in allo.reduction(Q):
            out_AB[i0, j0] += A[i0, k0] * B[k0, j0]

def mm2(out_AB: T[P, R], C: T[R, S], out_ABC: T[P, S]):
    for i1, j1 in allo.grid(P, S):
        for k1 in allo.reduction(R):
            out_ABC[i1, j1] += out_AB[i1, k1] * C[k1, j1]

def ele_add(out_ABC: T[P, S], D: T[P, S], output: T[P, S]):
    for i2, j2 in allo.grid(P, S):
        output[i2, j2] = out_ABC[i2, j2] * beta + D[i2, j2] * alpha

def kernel_2mm(A: T[P, Q], B: T[Q, R], C: T[R, S], D: T[P, S]) -> T[P, S]:
    out_AB: T[P, R]
    out_ABC: T[P, S]
    output: T[P, S]
    mm1(A, B, out_AB)
    mm2(out_AB, C, out_ABC)
    ele_add(out_ABC, D, output)
    return output

sch0 = allo.customize(mm1)
sch0.reorder("k0", "j0")
sch0.buffer_at(sch0.out_AB, axis="i0")
sch0.pipeline("k0")

sch1 = allo.customize(mm2)
sch1.reorder("k1", "j1")
sch1.buffer_at(sch1.out_ABC, axis="i1")
sch1.pipeline("k1")

sch2 = allo.customize(ele_add)
sch2.pipeline("j2")

sch = allo.customize(kernel_2mm)
sch.compose(sch0, sch1, sch2)

sch.to(sch.out_AB, "mm2")
sch.to(sch.out_ABC, "ele_add")

Buggy output

HLS code for mm2:

void mm2(
  hls::stream< float > &v15 /* v15[180][190] */,
  float v16[190][220],
  hls::stream< float > &v17 /* v17[180][220] */
) {     // L26
  #pragma HLS stream variable=v15 depth=34200
  #pragma HLS stream variable=v17 depth=39600
  l_S_i1_j1_0_i1: for (int i1 = 0; i1 < 180; i1++) {    // L27
    float v19[220];     // L28
    l_j1_init: for (int j1_init = 0; j1_init < 220; j1_init++) {        // L30
    #pragma HLS pipeline II=1
      v19[j1_init] = 0.000000;  // L31
    }
    l_S_k1_0_k1: for (int k1 = 0; k1 < 190; k1++) {     // L33
    #pragma HLS pipeline II=1
      l_j1: for (int j1 = 0; j1 < 220; j1++) {  // L34
        float v23 = v15.read(); // v15[i1][k1]; // L35
        float v24 = v16[k1][j1];        // L36
        float v25 = v23 * v24;  // L37
        float v26 = v19[j1];    // L38
        float v27 = v26 + v25;  // L39
        v19[j1] = v27;  // L40
      }
    }
    l_j1_back: for (int j1_back = 0; j1_back < 220; j1_back++) {        // L43
    #pragma HLS pipeline II=1
      float v29 = v19[j1_back]; // L44
      v17.write(v29); // v17[i1][j1_back] = v29;        // L45
    }
  }
}

Expected behavior
The v15.read() should be hoisted out of the inner j1 loop, since its value does not depend on j1:

void mm2(
  hls::stream< float > &v15 /* v15[180][190] */,
  float v16[190][220],
  hls::stream< float > &v17 /* v17[180][220] */
) {     // L26
  #pragma HLS stream variable=v15 depth=34200
  #pragma HLS stream variable=v17 depth=39600
  l_S_i1_j1_0_i1: for (int i1 = 0; i1 < 180; i1++) {    // L27
    float v19[220];     // L28
    l_j1_init: for (int j1_init = 0; j1_init < 220; j1_init++) {        // L30
    #pragma HLS pipeline II=1
      v19[j1_init] = 0.000000;  // L31
    }
    l_S_k1_0_k1: for (int k1 = 0; k1 < 190; k1++) {     // L33
    #pragma HLS pipeline II=1
      float v23 = v15.read(); // v15[i1][k1]; // L35
      l_j1: for (int j1 = 0; j1 < 220; j1++) {  // L34
        float v24 = v16[k1][j1];        // L36
        float v25 = v23 * v24;  // L37
        float v26 = v19[j1];    // L38
        float v27 = v26 + v25;  // L39
        v19[j1] = v27;  // L40
      }
    }
    l_j1_back: for (int j1_back = 0; j1_back < 220; j1_back++) {        // L43
    #pragma HLS pipeline II=1
      float v29 = v19[j1_back]; // L44
      v17.write(v29); // v17[i1][j1_back] = v29;        // L45
    }
  }
}

[Feature] Rewind memory access loops

Issue
Currently, the loops for accessing top-level variables are automatically generated with pipelining at II=1, which is great. However, in my testing this can still lead to 10x the theoretical runtime for 2D arrays.

Solution
Adding rewind to the end of the automatically generated pipeline pragmas fully solves this performance issue while sometimes also reducing hardware usage.
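
In Vivado HLS, rewind lets successive executions of a pipelined loop overlap instead of fully draining between invocations, i.e. the generated copy loops would carry #pragma HLS pipeline II=1 rewind. A hypothetical user-facing knob (proposed interface only, not an existing Allo API) might look like:

s = allo.customize(kernel)
s.pipeline("i", rewind=True)  # hypothetical flag: emit "#pragma HLS pipeline II=1 rewind"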

Example: my matrix-vector multiply program.
Without rewind (current setup):
78 cycle interval for buf1

With rewind manually added:
4 cycle interval achieved.

[BUG] FFN memory usage not optimized due to difficulty of fusing linalg ops

Describe the bug
A linear layer or activation function should allocate at most one memory block, and there are opportunities for reuse between different layers. However, the current builder allocates a separate memory block for operations such as transpose or broadcast. Presumably the overhead can be eliminated by fusing linalg.transpose with linalg.fill and linalg.matmul. I tried rewriting mlp.py using for loops and succeeded in using only 2 allocations.

To Reproduce
Run mlp.py with monitor_memory and without enable_tensor. The total number of allocations is ten.
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| name | shape | dtype | mem(bits) | BRAM(18K) | store counts | data storage |
+===========+==========+=========+=============+=============+================+============================================================================+
| %alloc | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %0[%arg1, %arg2] : memref<30x30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_3 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %8 = arith.addf %6, %7 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_10 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %1[%arg2] : memref<30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_14 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %6 = arith.addf %4, %5 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_21 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %2[%arg1, %arg2] : memref<30x30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_28 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %8 = arith.addf %6, %7 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_35 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %3[%arg2] : memref<30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_39 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %6 = arith.addf %4, %5 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_46 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %6 = arith.maxf %4, %5 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_50 | [30, 30] | f32 | 28800 | 1.6384e+06 | 0 | |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| Total(10) | | | 288000 | 1.6384e+07 | | *data storage: data stored into an allocated memory. Doesn't include init. |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+

Expected behavior
If rewriting the FFN with for loops:

import allo
import numpy as np
from allo.ir.types import float32
# monitor_memory_usage is Allo's memory-profiling helper (import path omitted in the original report)

def test_for_FFN():
    w1_const = np.float32(np.random.uniform(size=(30, 30)))
    w2_const = np.float32(np.random.uniform(size=(30, 30)))
    b1_const = np.float32(np.random.uniform(size=(30)))
    b2_const = np.float32(np.random.uniform(size=(30)))
    def kernel(A: float32[30, 30]) -> float32[30, 30]:
        w1: float32[30, 30] = w1_const
        w2: float32[30, 30] = w2_const
        b1: float32[30] = b1_const
        b2: float32[30] = b2_const
        B: float32[30, 30] = 0
        for i, j in allo.grid(30, 30):
            for k in allo.reduction(30):
                B[i, j] += A[k, i] * w1[k, j]
                B[i, j] += b1[j]
        C: float32[30, 30] = 0
        for i, j in allo.grid(30, 30):
            for k in allo.reduction(30):
                C[i, j] += B[k, i] * w2[k, j]
                C[i, j] += b2[j]
        for i, j in allo.grid(30, 30):
            C[i, j] = allo.max(C[i, j], 0.0)
        return C

    s = allo.customize(kernel, verbose=True)
    print(s.module)
    mod = s.build()
    monitor_memory_table = monitor_memory_usage(mod.intermediate_module)
    print(monitor_memory_table)

There are only 2 allocations, which means our current builder is far from optimal.

[Feature] Comment block support

Problem
The standard Python convention for docstrings requires a block comment (using """). However, in Allo functions, comment blocks aren't properly ignored and throw an error.

Solution
Support for comment blocks would allow standard docstrings to work.
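
A minimal reproducer sketch (hypothetical, but any kernel carrying a docstring should trigger the error):

import allo
from allo.ir.types import int32

def kernel(A: int32[10]) -> int32[10]:
    """Copy A into B elementwise."""  # this docstring currently raises an error
    B: int32[10] = 0
    for i in range(10):
        B[i] = A[i]
    return B

s = allo.customize(kernel)  # fails while traversing the kernel's AST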

[BUG] Installation Issue

Describe the bug
Installation fails with CMake error

To Reproduce
Run bash build.sh at checkout hash e2f1e16.

Buggy output

-- Targeting X86
-- Performing Test C_SUPPORTS_WERROR_IMPLICIT_FUNCTION_DECLARATION                                      
-- Performing Test C_SUPPORTS_WERROR_IMPLICIT_FUNCTION_DECLARATION - Success                            
-- Performing Test C_SUPPORTS_WERROR_MISMATCHED_TAGS                                                    
-- Performing Test C_SUPPORTS_WERROR_MISMATCHED_TAGS - Failed                                           
-- Found Python3: /tools/Tools/MiniConda/envs/allo/bin/python3 (found suitable version "3.12.1", minimum
required is "3.6") found components: Interpreter Development Development.Module Development.Embed       
CMake Error at /usr/share/cmake/Modules/FindPackageHandleStandardArgs.cmake:230 (message):              
  Could NOT find Python3 (missing: Python3_NumPy_INCLUDE_DIRS NumPy) (found                             
  suitable version "3.12.1", minimum required is "3.6")                                                 
Call Stack (most recent call first):
  /usr/share/cmake/Modules/FindPackageHandleStandardArgs.cmake:600 (_FPHSA_FAILURE_MESSAGE)             
  /usr/share/cmake/Modules/FindPython/Support.cmake:3766 (find_package_handle_standard_args)            
  /usr/share/cmake/Modules/FindPython3.cmake:551 (include)                                              
  /tools/ToolSource/allo/externals/hcl_mlir/externals/llvm-project/mlir/cmake/modules/MLIRDetectPythonEnv.cmake:21 (find_package)
  /tools/ToolSource/allo/externals/hcl_mlir/externals/llvm-project/mlir/CMakeLists.txt:156 (mlir_configure_python_dev_packages)


-- Configuring incomplete, errors occurred!
make: *** No targets specified and no makefile found.  Stop.                                            
LLVM build directory: /tools/ToolSource/allo/externals/hcl_mlir/externals/llvm-project/build            
Building hcl dialect ...
-- The CXX compiler identification is GNU 8.5.0
-- The C compiler identification is GNU 8.5.0
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped                                               
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped                                                  
-- Detecting C compile features
-- Detecting C compile features - done
CMake Error at CMakeLists.txt:16 (find_package):
  Could not find a package configuration file provided by "MLIR" with any of                            
  the following names:

    MLIRConfig.cmake
    mlir-config.cmake

  Add the installation prefix of "MLIR" to CMAKE_PREFIX_PATH or set                                     
  "MLIR_DIR" to a directory containing one of the above files.  If "MLIR"                               
  provides a separate development package or SDK, be sure it has been                                   
  installed.


-- Configuring incomplete, errors occurred!
make: *** No targets specified and no makefile found.  Stop.                                            
build.sh: line 48: cd: tools/hcl/python_packages/hcl_core: No such file or directory                    
Obtaining file:///tools/ToolSource/allo/externals/hcl_mlir/build                                        
ERROR: file:///tools/ToolSource/allo/externals/hcl_mlir/build does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.
Installation completed!

Additional context

Created conda virtual environment using conda from Miniconda (https://docs.conda.io/projects/miniconda/en/latest/).

(allo) [dpal2@pal-achieve-01 allo]$ which python
/tools/Tools/MiniConda/envs/allo/bin/python
(allo) [dpal2@pal-achieve-01 allo]$ python --version
Python 3.12.1
(allo) [dpal2@pal-achieve-01 allo]$ 

GCC = 8.5.0 (Red Hat 8.5)
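
Note: the root failure is the first CMake error, which reports that NumPy headers (Python3_NumPy_INCLUDE_DIRS) cannot be found in the active environment. A plausible fix, though not verified here, is to install NumPy into the conda environment (e.g. python3 -m pip install numpy) and rerun bash build.sh; the subsequent MLIRConfig.cmake and pip errors look like cascading failures from the aborted LLVM build.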

[BUG] Randomly failed MLIR module when cascading .partition() and .compose()

The following program randomly fails when building the MLIR module. I haven't figured out the reason. Since it sometimes works and sometimes doesn't, debugging is extremely painful.

def test_partition_and_compose():
    inp_num = 12
    inp_len = 768
    Max_size = 12

    def Linear_layer_q(
        inp: float32[inp_num, inp_len],
        W: float32[inp_len, inp_len],
        B: float32[inp_len],
    ) -> float32[inp_num, inp_len]:
        outp: float32[inp_num, inp_len]
        for i, j in allo.grid(inp_num, inp_len, name="bias"):
            outp[i, j] = B[j]
        for i, j, k in allo.grid(inp_num, inp_len, inp_len, name="gemm"):
            outp[i, j] += inp[i, k] * W[j, k]
        return outp

    def top(
        inp: float32[inp_num, inp_len],
        W: float32[inp_len, inp_len],
        B: float32[inp_len],
    ) -> float32[inp_num, inp_len]:
        outp: float32[inp_num, inp_len]
        outp = Linear_layer_q(inp, W, B)
        return outp

    s_q = allo.customize(Linear_layer_q)
    s_q.partition(Linear_layer_q.inp, partition_type=2, dim=1, factor=Max_size)
    s_q.partition(Linear_layer_q.W, partition_type=2, dim=1, factor=Max_size)
    s = allo.customize(top)
    s.compose(s_q)
    print(s.build(target="vhls"))
python3: /scratch/users/hc676/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:973: static mlir::python::PyOperationRef mlir::python::PyOperation::createDetached(mlir::python::PyMlirContextRef, MlirOperation, pybind11::object): Assertion `liveOperations.count(operation.ptr) == 0 && "cannot create detached operation that already exists"' failed.
 #0 0x00007effa71ff1cf PrintStackTraceSignalHandler(void*) Signals.cpp:0:0
 #1 0x00007effa71fcebc SignalHandler(int) Signals.cpp:0:0
 #2 0x00007f002d8f6630 __restore_rt sigaction.c:0:0
 #3 0x00007f002ce46387 raise (/lib64/libc.so.6+0x36387)
 #4 0x00007f002ce47a78 abort (/lib64/libc.so.6+0x37a78)
 #5 0x00007f002ce3f1a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)
 #6 0x00007f002ce3f252 (/lib64/libc.so.6+0x2f252)
 #7 0x00007effad234839 mlir::python::PyOperation::createDetached(mlir::python::PyObjectRef<mlir::python::PyMlirContext>, MlirOperation, pybind11::object) /scratch/users/hc676/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:972:0
 #8 0x00007effad236211 mlir::python::PyOperation::create(std::string const&, llvm::Optional<std::vector<mlir::python::PyType*, std::allocator<mlir::python::PyType*>>>, llvm::Optional<std::vector<mlir::python::PyValue*, std::allocator<mlir::python::PyValue*>>>, llvm::Optional<pybind11::dict>, llvm::Optional<std::vector<mlir::python::PyBlock*, std::allocator<mlir::python::PyBlock*>>>, int, mlir::python::DefaultingPyLocation, pybind11::object const&) /scratch/users/hc676/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1229:0
 #9 0x00007effad2386f0 mlir::python::PyOpView::buildGeneric(pybind11::object const&, pybind11::list, pybind11::list, llvm::Optional<pybind11::dict>, llvm::Optional<std::vector<mlir::python::PyBlock*, std::allocator<mlir::python::PyBlock*>>>, llvm::Optional<int>, mlir::python::DefaultingPyLocation, pybind11::object const&) /scratch/users/hc676/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1525:0
#10 0x00007effad2ea786 pybind11::object pybind11::detail::argument_loader<pybind11::object const&, pybind11::list, pybind11::list, llvm::Optional<pybind11::dict>, llvm::Optional<std::vector<mlir::python::PyBlock*, std::allocator<mlir::python::PyBlock*>>>, llvm::Optional<int>, mlir::python::DefaultingPyLocation, pybind11::object const&>::call_impl<pybind11::object, pybind11::object (*&)(pybind11::object const&, pybind11::list, pybind11::list, llvm::Optional<pybind11::dict>, llvm::Optional<std::vector<mlir::python::PyBlock*, std::allocator<mlir::python::PyBlock*>>>, llvm::Optional<int>, mlir::python::DefaultingPyLocation, pybind11::object const&), 0ul, 1ul, 2ul, 3ul, 4ul, 5ul, 6ul, 7ul, pybind11::detail::void_type>(pybind11::object (*&)(pybind11::object const&, pybind11::list, pybind11::list, llvm::Optional<pybind11::dict>, llvm::Optional<std::vector<mlir::python::PyBlock*, std::allocator<mlir::python::PyBlock*>>>, llvm::Optional<int>, mlir::python::DefaultingPyLocation, pybind11::object const&), std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul, 5ul, 6ul, 7ul>, pybind11::detail::void_type&&) && /home/hc676/.conda/envs/hcltest/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1480:0

[BUG] Operations unsupported when generating HLS from lowered MLIR

Describe the bug
When using build(target="vhls"), certain operations cannot pass hcl_d.emit_vhls(self.module, buf), failing with errors of the form error: '<name of op>' op is unsupported operation. The list of these operations and the test case each belongs to is provided in Additional context below.

To Reproduce

def test_while_basic():
    def kernel(A: int32[10]):
        i: index = 0
        while i < 10:
            A[i] = i
            i += 1

    s = allo.customize(kernel, verbose=True)
    print(s.module)
    mod = s.build(target = "vhls")

Buggy output

loc("-":8:5): error: 'scf.while' op is unsupported operation.
 #0 0x00007f76b2eed9d8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x185b9d8)
 #1 0x00007f76b2eeb25c SignalHandler(int) Signals.cpp:0:0
 #2 0x00007f773a2f7630 __restore_rt sigaction.c:0:0
 #3 0x00007f7739847387 raise (/lib64/libc.so.6+0x36387)
 #4 0x00007f7739848a78 abort (/lib64/libc.so.6+0x37a78)
 #5 0x00007f76b3071f71 /work/shared/users/common/hcl-dialect-18.x/include/hcl/Dialect/Visitor.h:89:10
 #6 0x00007f76b3071cd4 bool mlir::hcl::HLSCppVisitorBase<(anonymous namespace)::ExprVisitor, bool>::dispatchVisitor(mlir::Operation*)::'lambda0'(auto)::operator()<mlir::Operation*>(auto) const /work/shared/users/common/hcl-dialect-18.x/include/hcl/Dialect/Visitor.h:81:18
 #7 0x00007f76b3071ab7 bool llvm::TypeSwitch<mlir::Operation*, bool>::Default<mlir::hcl::HLSCppVisitorBase<(anonymous namespace)::ExprVisitor, bool>::dispatchVisitor(mlir::Operation*)::'lambda0'(auto)>(auto&&) /work/shared/users/common/llvm-project-18.x/llvm/include/llvm/ADT/TypeSwitch.h:131:33
 #8 0x00007f76b306ff33 mlir::hcl::HLSCppVisitorBase<(anonymous namespace)::ExprVisitor, bool>::dispatchVisitor(mlir::Operation*) /work/shared/users/common/hcl-dialect-18.x/include/hcl/Dialect/Visitor.h:84:3
 #9 0x00007f76b306a2f0 (anonymous namespace)::ModuleEmitter::emitBlock(mlir::Block&) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:1692:5
#10 0x00007f76b306bd6e (anonymous namespace)::ModuleEmitter::emitFunction(mlir::func::FuncOp) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:2008:20
#11 0x00007f76b306c306 (anonymous namespace)::ModuleEmitter::emitModule(mlir::ModuleOp) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:2099:21
#12 0x00007f76b306c425 mlir::hcl::emitVivadoHLS(mlir::ModuleOp, llvm::raw_ostream&) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:2115:24
#13 0x00007f76b2c9265e mlirEmitVivadoHls /work/shared/users/common/hcl-dialect-18.x/lib/CAPI/Translation/EmitVivadoHLS.cpp:19:52
#14 0x00007f772f25ff20 emitVivadoHls(MlirModule&, pybind11::object) //work/shared/users/common/hcl-dialect-18.x/lib/Bindings/Python/HCLModule.cpp:91:36
#15 0x00007f772f28869d bool pybind11::detail::argument_loader<MlirModule&, pybind11::object>::call_impl<bool, bool (*&)(MlirModule&, pybind11::object), 0ul, 1ul, pybind11::detail::void_type>(bool (*&)(MlirModule&, pybind11::object), std::integer_sequence<unsigned long, 0ul, 1ul>, pybind11::detail::void_type&&) && /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1480:37
#16 0x00007f772f285713 _ZNO8pybind116detail15argument_loaderIJR10MlirModuleNS_6objectEEE4callIbNS0_9void_typeERPFbS3_S4_EEENSt9enable_ifIXntsrSt7is_voidIT_E5valueESD_E4typeEOT1_ /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1450:5
#17 0x00007f772f280806 void pybind11::cpp_function::initialize<bool (*&)(MlirModule&, pybind11::object), bool, MlirModule&, pybind11::object, pybind11::name, pybind11::scope, pybind11::sibling>(bool (*&)(MlirModule&, pybind11::object), bool (*)(MlirModule&, pybind11::object), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::'lambda1'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:253:40
#18 0x00007f772f280940 void pybind11::cpp_function::initialize<bool (*&)(MlirModule&, pybind11::object), bool, MlirModule&, pybind11::object, pybind11::name, pybind11::scope, pybind11::sibling>(bool (*&)(MlirModule&, pybind11::object), bool (*)(MlirModule&, pybind11::object), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::'lambda1'(pybind11::detail::function_call&)::_FUN(pybind11::detail::function_call&) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:224:21
#19 0x00007f772f26de5a pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:946:35
#20 0x00000000004f5592 cfunction_call_varargs /usr/local/src/conda/python-3.8.17/Objects/call.c:745:9
#21 0x00000000004f5592 PyCFunction_Call /usr/local/src/conda/python-3.8.17/Objects/call.c:773:16
#22 0x00000000004e0e1b _PyObject_MakeTpCall /usr/local/src/conda/python-3.8.17/Objects/call.c:159:18
#23 0x00000000004dcf24 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:125:16
#24 0x00000000004dcf24 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:115:1
#25 0x00000000004dcf24 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#26 0x00000000004dcf24 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3469:23
#27 0x00000000004d70e1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.17/Python/ceval.c:4308:9
#28 0x00000000004e823c _PyFunction_Vectorcall /usr/local/src/conda/python-3.8.17/Objects/call.c:436:12
#29 0x00000000004e05d2 _PyObject_FastCallDict /usr/local/src/conda/python-3.8.17/Objects/call.c:105:21
#30 0x00000000004f1c13 _PyObject_Call_Prepend /usr/local/src/conda/python-3.8.17/Objects/call.c:888:14
#31 0x00000000004f1c13 slot_tp_init /usr/local/src/conda/python-3.8.17/Objects/typeobject.c:6790:15
#32 0x00000000004e0e33 type_call /usr/local/src/conda/python-3.8.17/Objects/typeobject.c:995:12
#33 0x00000000004e0e33 _PyObject_MakeTpCall /usr/local/src/conda/python-3.8.17/Objects/call.c:159:18
#34 0x00000000004dd0d6 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:125:16
#35 0x00000000004dd0d6 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:115:1
#36 0x00000000004dd0d6 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#37 0x00000000004dd0d6 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3515:19
#38 0x00000000004d70e1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.17/Python/ceval.c:4308:9
#39 0x00000000004f50fb _PyFunction_Vectorcall /usr/local/src/conda/python-3.8.17/Objects/call.c:436:12
#40 0x00000000004f50fb _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:127:11
#41 0x00000000004f50fb method_vectorcall /usr/local/src/conda/python-3.8.17/Objects/classobject.c:60:18
#42 0x00000000004d9276 PyErr_Occurred /usr/local/src/conda/python-3.8.17/Python/errors.c:221:29
#43 0x00000000004d9276 _Py_CheckFunctionResult /usr/local/src/conda/python-3.8.17/Objects/call.c:25:25
#44 0x00000000004d9276 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:128:12
#45 0x00000000004d9276 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#46 0x00000000004d9276 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3515:19
#47 0x00000000004e81a6 function_code_fastcall /usr/local/src/conda/python-3.8.17/Objects/call.c:286:9
#48 0x00000000004e81a6 _PyFunction_Vectorcall /usr/local/src/conda/python-3.8.17/Objects/call.c:411:20
#49 0x00000000004d84b9 PyErr_Occurred /usr/local/src/conda/python-3.8.17/Python/errors.c:221:29
#50 0x00000000004d84b9 _Py_CheckFunctionResult /usr/local/src/conda/python-3.8.17/Objects/call.c:25:25
#51 0x00000000004d84b9 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:128:12
#52 0x00000000004d84b9 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#53 0x00000000004d84b9 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3500:19
#54 0x00000000004d70e1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.17/Python/ceval.c:4308:9
#55 0x0000000000585e99 PyEval_EvalCodeEx /usr/local/src/conda/python-3.8.17/Python/ceval.c:4334:1
#56 0x0000000000585e5b PyEval_EvalCode /usr/local/src/conda/python-3.8.17/Python/ceval.c:724:1
#57 0x00000000005a5c21 run_eval_code_obj /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:1166:9
#58 0x00000000005a4c2f _Py_DECREF /usr/local/src/conda/python-3.8.17/Include/object.h:470:8
#59 0x00000000005a4c2f run_mod /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:1189:5
#60 0x000000000045c580 pyrun_file /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:1085:15
#61 0x000000000045c121 pyrun_simple_file /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:439:13
#62 0x000000000045c121 PyRun_SimpleFileExFlags /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:472:15
#63 0x000000000044fe93 _Py_DECREF /usr/local/src/conda/python-3.8.17/Include/object.h:470:8
#64 0x000000000044fe93 _Py_XDECREF /usr/local/src/conda/python-3.8.17/Include/object.h:541:9
#65 0x000000000044fe93 pymain_run_file /usr/local/src/conda/python-3.8.17/Modules/main.c:392:5
#66 0x000000000044fe93 pymain_run_python /usr/local/src/conda/python-3.8.17/Modules/main.c:616:21
#67 0x000000000044fe93 Py_RunMain.cold /usr/local/src/conda/python-3.8.17/Modules/main.c:695:5
#68 0x0000000000579ef9 Py_BytesMain /usr/local/src/conda/python-3.8.17/Modules/main.c:1128:1
#69 0x00007f7739833555 __libc_start_main (/lib64/libc.so.6+0x22555)
#70 0x0000000000579dad _start (/home/md2249/miniconda3/envs/allo/bin/python3.8+0x579dad)
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Aborted

Expected behavior
It should output HLS code in the form given in the Allo tutorial.

Additional context

  • test_builder.py
    scf.while: test_while_basic()
    arith.floordivsi: test_rhs_binaryop()
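
Until the emitter supports these operations, a hedged workaround for the scf.while case is to express the same computation with a bounded for loop, which lowers to affine.for and is accepted by the HLS emitter:

def kernel(A: int32[10]):
    for i in range(10):
        A[i] = i  # same effect as the while loop in test_while_basic()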

[BUG] Cannot find target loop at schedule composition

Describe the bug
A loop cannot be found during schedule composition, although it works if the kernel is customized alone.

To Reproduce

import allo
from allo.ir.types import int32

N = 256
M = 16


def compute_dist(
    position_x: int32[N],
    position_y: int32[N],
    position_z: int32[N],
    NL: int32[N, M],
    del_x: int32[N],
    del_y: int32[N],
    del_z: int32[N],
):
    for i0, j0 in allo.grid(N, M):
        del_x[i0] = position_x[i0] - position_x[NL[i0, j0]]
        del_y[i0] = position_y[i0] - position_y[NL[i0, j0]]
        del_z[i0] = position_z[i0] - position_z[NL[i0, j0]]


def kernel_md(
    position_x: int32[N],
    position_y: int32[N],
    position_z: int32[N],
    NL: int32[N, M],
    force_x: int32[N],
    force_y: int32[N],
    force_z: int32[N],
):
    del_x: int32[N]
    del_y: int32[N]
    del_z: int32[N]
    compute_dist(position_x, position_y, position_z, NL, del_x, del_y, del_z)


s0 = allo.customize(compute_dist)
s0.split("i0", factor=16)
s0.pipeline("i0.inner")
s = allo.customize(kernel_md)
s.compose(s0)

Buggy output

Traceback (most recent call last):
  File "/work/shared/users/phd/nz264/machsuite-allo/md/knn/md_opt.py", line 42, in <module>
    s.compose(s0)
  File "/work/shared/users/phd/nz264/allo/allo/customize.py", line 108, in wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/work/shared/users/phd/nz264/allo/allo/customize.py", line 676, in compose
    primitive_func.__wrapped__(self, *args, **kwargs)
  File "/work/shared/users/phd/nz264/allo/allo/customize.py", line 403, in pipeline
    band_name, axis = find_loop_in_bands(func, axis)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/shared/users/phd/nz264/allo/allo/ir/transform.py", line 149, in find_loop_in_bands
    raise RuntimeError(f"Cannot find the band of loop {axis_name}")
RuntimeError: Cannot find the band of loop i0.inner

Expected behavior
The target loop should be found during schedule replay.

Additional context
This is part of the md/knn/ kernel in MachSuite.

[BUG] Cannot lower loops with tensors to llvm when using tensor dialect

Describe the bug
When operating on tensors in a for loop using the tensor dialect, for example inserting values into slices, the module cannot be lowered to LLVM correctly. Tests report error: unknown: operand #0 does not dominate this use together with the offending operation.

To Reproduce
In tests/test_linalg.py, def test_math_scalar():

def kernel(A: float32[M, K], B: float32[K, N]) -> float32[M, N]:
    C: float32[M, N] = 0.0
    D: float32[M, N] = 0.0
    for i, j in allo.grid(M, N):
        for k in allo.reduction(K):
            C[i, j] += A[i, k] * B[k, j]
    for i, j in allo.grid(M, N):
        D[i, j] = (allo.exp(C[i, j]) + allo.log(C[i, j])) / C[i, j]
    return D

The module after customize() is:

"builtin.module"() ({
  "func.func"() <{function_type = (tensor<10x15xf32>, tensor<15x20xf32>) -> tensor<10x20xf32>, sym_name = "kernel"}> ({
  ^bb0(%arg0: tensor<10x15xf32>, %arg1: tensor<15x20xf32>):
    %0 = "tensor.generate"() ({
    ^bb0(%arg2: index, %arg3: index):
      %2 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
      "tensor.yield"(%2) : (f32) -> ()
    }) : () -> tensor<10x20xf32>
    %1 = "tensor.generate"() ({
    ^bb0(%arg2: index, %arg3: index):
      %2 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
      "tensor.yield"(%2) : (f32) -> ()
    }) : () -> tensor<10x20xf32>
    "affine.for"() ({
    ^bb0(%arg2: index):
      "affine.for"() ({
      ^bb0(%arg3: index):
        "affine.for"() ({
        ^bb0(%arg4: index):
          %2 = "tensor.extract"(%arg0, %arg2, %arg4) : (tensor<10x15xf32>, index, index) -> f32
          %3 = "tensor.extract"(%arg1, %arg4, %arg3) : (tensor<15x20xf32>, index, index) -> f32
          %4 = "arith.mulf"(%2, %3) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %5 = "tensor.extract"(%0, %arg2, %arg3) {from = "C"} : (tensor<10x20xf32>, index, index) -> f32
          %6 = "arith.addf"(%5, %4) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %7 = "tensor.insert"(%6, %0, %arg2, %arg3) : (f32, tensor<10x20xf32>, index, index) -> tensor<10x20xf32>
          "affine.yield"() : () -> ()
        }) {loop_name = "k", lower_bound = #map, op_name = "S_k_0", reduction, step = 1 : i32, upper_bound = #map1} : () -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map, step = 1 : i32, upper_bound = #map2} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "affine.for"() ({
    ^bb0(%arg2: index):
      "affine.for"() ({
      ^bb0(%arg3: index):
        %2 = "tensor.extract"(%0, %arg2, %arg3) : (tensor<10x20xf32>, index, index) -> f32
        %3 = "math.exp"(%2) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
        %4 = "tensor.extract"(%0, %arg2, %arg3) : (tensor<10x20xf32>, index, index) -> f32
        %5 = "math.log"(%4) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
        %6 = "arith.addf"(%3, %5) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
        %7 = "tensor.extract"(%0, %arg2, %arg3) : (tensor<10x20xf32>, index, index) -> f32
        %8 = "arith.divf"(%6, %7) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
        %9 = "tensor.insert"(%8, %1, %arg2, %arg3) : (f32, tensor<10x20xf32>, index, index) -> tensor<10x20xf32>
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map, step = 1 : i32, upper_bound = #map2} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map, op_name = "S_i_j_2", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%9) : (tensor<10x20xf32>) -> ()
  }) {itypes = "__", otypes = "_"} : () -> ()
}) : () -> ()

Buggy output

File "/home/md2249/allo/allo/passes.py", line 25, in _mlir_lower_pipeline
  mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs._site_initialize.<locals>.MLIRError: Failure while executing pass pipeline:
error: unknown: operand #0 does not dominate this use
note: unknown: see current operation: "func.return"(%9) : (tensor<10x20xf32>) -> ()
note: unknown: operand defined here (op in a child region)

Expected behavior
The following passes (already used in allo/backend/llvm.py, class LLVMModule) are applied:

pm = PassManager.parse(
    "builtin.module("
    # used for lowering tensor.empty
    "empty-tensor-to-alloc-tensor,"
    # translate tensor dialect (virtual) to memref dialect (physical)
    "one-shot-bufferize{allow-return-allocs bufferize-function-boundaries},"
    # used for lowering memref.subview
    "expand-strided-metadata,"
    # common lowering passes
    "func.func(convert-linalg-to-affine-loops),lower-affine"
    ")"
)

I hope to lower these loop cases to LLVM without errors.


[BUG] MLIR module after lowering cannot generate HLS when using certain tensor ops

Describe the bug
When trying to run build(target="vhls") on the test cases in test_tensor.py, after lowering the modules using the same passes added to allo/backend/llvm.py, all cases except test_same(), test_same_scalar(), test_outzero(), and test_outzero_scalar() fail at the line hcl_d.emit_vhls(self.module, buf) in allo/backend/hls.py.

To Reproduce
Run test_extract()

def test_extract():
    def extract(A: int32[6, 6]) -> int32[1, 2]:
        return A[1:2, 1:3]

    s = allo.customize(extract, enable_tensor=True)
    print(s.module)

    mod = s.build(target="vhls")
    print(mod)

Buggy output

python3: /work/shared/users/common/llvm-project-18.x/llvm/include/llvm/ADT/ArrayRef.h:257: const T& llvm::ArrayRef<T>::operator[](size_t) const [with T = mlir::AffineExpr; size_t = long unsigned int]: Assertion `Index < Length && "Invalid index!"' failed.
 #0 0x00007f1ee69329d8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x185b9d8)
 #1 0x00007f1ee693025c SignalHandler(int) Signals.cpp:0:0
 #2 0x00007f1f6e914630 __restore_rt sigaction.c:0:0
 #3 0x00007f1f6de64387 raise (/lib64/libc.so.6+0x36387)
 #4 0x00007f1f6de65a78 abort (/lib64/libc.so.6+0x37a78)
 #5 0x00007f1f6de5d1a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)
 #6 0x00007f1f6de5d252 (/lib64/libc.so.6+0x2f252)
 #7 0x00007f1ee6710983 (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x1639983)
 #8 0x00007f1ee6713f47 (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x163cf47)
 #9 0x00007f1ee8eac24a mlir::hcl::getPartitionFactors(mlir::MemRefType, llvm::SmallVector<long, 8u>*) /work/shared/users/common/hcl-dialect-18.x/lib/Support/Utils.cpp:393:42
#10 0x00007f1ee6aaf747 (anonymous namespace)::ModuleEmitter::emitArrayDirectives(mlir::Value) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:1740:18
#11 0x00007f1ee6aafdc7 (anonymous namespace)::ModuleEmitter::emitFunctionDirectives(mlir::func::FuncOp, llvm::ArrayRef<mlir::Value>) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:1885:3
#12 0x00007f1ee6ab0cf1 (anonymous namespace)::ModuleEmitter::emitFunction(mlir::func::FuncOp) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:2004:20
#13 0x00007f1ee6ab1306 (anonymous namespace)::ModuleEmitter::emitModule(mlir::ModuleOp) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:2099:21
#14 0x00007f1ee6ab1425 mlir::hcl::emitVivadoHLS(mlir::ModuleOp, llvm::raw_ostream&) /work/shared/users/common/hcl-dialect-18.x/lib/Translation/EmitVivadoHLS.cpp:2115:24
#15 0x00007f1ee66d765e mlirEmitVivadoHls /work/shared/users/common/hcl-dialect-18.x/lib/CAPI/Translation/EmitVivadoHLS.cpp:19:52
#16 0x00007f1ee4749f20 emitVivadoHls(MlirModule&, pybind11::object) //work/shared/users/common/hcl-dialect-18.x/lib/Bindings/Python/HCLModule.cpp:91:36
#17 0x00007f1ee477269d bool pybind11::detail::argument_loader<MlirModule&, pybind11::object>::call_impl<bool, bool (*&)(MlirModule&, pybind11::object), 0ul, 1ul, pybind11::detail::void_type>(bool (*&)(MlirModule&, pybind11::object), std::integer_sequence<unsigned long, 0ul, 1ul>, pybind11::detail::void_type&&) && /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1480:37
#18 0x00007f1ee476f713 _ZNO8pybind116detail15argument_loaderIJR10MlirModuleNS_6objectEEE4callIbNS0_9void_typeERPFbS3_S4_EEENSt9enable_ifIXntsrSt7is_voidIT_E5valueESD_E4typeEOT1_ /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1450:5
#19 0x00007f1ee476a806 void pybind11::cpp_function::initialize<bool (*&)(MlirModule&, pybind11::object), bool, MlirModule&, pybind11::object, pybind11::name, pybind11::scope, pybind11::sibling>(bool (*&)(MlirModule&, pybind11::object), bool (*)(MlirModule&, pybind11::object), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::'lambda1'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:253:40
#20 0x00007f1ee476a940 void pybind11::cpp_function::initialize<bool (*&)(MlirModule&, pybind11::object), bool, MlirModule&, pybind11::object, pybind11::name, pybind11::scope, pybind11::sibling>(bool (*&)(MlirModule&, pybind11::object), bool (*)(MlirModule&, pybind11::object), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::'lambda1'(pybind11::detail::function_call&)::_FUN(pybind11::detail::function_call&) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:224:21
#21 0x00007f1ee4757e5a pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:946:35
#22 0x00000000004f5592 cfunction_call_varargs /usr/local/src/conda/python-3.8.17/Objects/call.c:745:9
#23 0x00000000004f5592 PyCFunction_Call /usr/local/src/conda/python-3.8.17/Objects/call.c:773:16
#24 0x00000000004e0e1b _PyObject_MakeTpCall /usr/local/src/conda/python-3.8.17/Objects/call.c:159:18
#25 0x00000000004dcf24 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:125:16
#26 0x00000000004dcf24 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:115:1
#27 0x00000000004dcf24 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#28 0x00000000004dcf24 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3469:23
#29 0x00000000004d70e1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.17/Python/ceval.c:4308:9
#30 0x00000000004e823c _PyFunction_Vectorcall /usr/local/src/conda/python-3.8.17/Objects/call.c:436:12
#31 0x00000000004e05d2 _PyObject_FastCallDict /usr/local/src/conda/python-3.8.17/Objects/call.c:105:21
#32 0x00000000004f1c13 _PyObject_Call_Prepend /usr/local/src/conda/python-3.8.17/Objects/call.c:888:14
#33 0x00000000004f1c13 slot_tp_init /usr/local/src/conda/python-3.8.17/Objects/typeobject.c:6790:15
#34 0x00000000004e0e33 type_call /usr/local/src/conda/python-3.8.17/Objects/typeobject.c:995:12
#35 0x00000000004e0e33 _PyObject_MakeTpCall /usr/local/src/conda/python-3.8.17/Objects/call.c:159:18
#36 0x00000000004dd0d6 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:125:16
#37 0x00000000004dd0d6 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:115:1
#38 0x00000000004dd0d6 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#39 0x00000000004dd0d6 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3515:19
#40 0x00000000004d70e1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.17/Python/ceval.c:4308:9
#41 0x00000000004f50fb _PyFunction_Vectorcall /usr/local/src/conda/python-3.8.17/Objects/call.c:436:12
#42 0x00000000004f50fb _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:127:11
#43 0x00000000004f50fb method_vectorcall /usr/local/src/conda/python-3.8.17/Objects/classobject.c:60:18
#44 0x00000000004d9276 PyErr_Occurred /usr/local/src/conda/python-3.8.17/Python/errors.c:221:29
#45 0x00000000004d9276 _Py_CheckFunctionResult /usr/local/src/conda/python-3.8.17/Objects/call.c:25:25
#46 0x00000000004d9276 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:128:12
#47 0x00000000004d9276 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#48 0x00000000004d9276 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3515:19
#49 0x00000000004e81a6 function_code_fastcall /usr/local/src/conda/python-3.8.17/Objects/call.c:286:9
#50 0x00000000004e81a6 _PyFunction_Vectorcall /usr/local/src/conda/python-3.8.17/Objects/call.c:411:20
#51 0x00000000004d84b9 PyErr_Occurred /usr/local/src/conda/python-3.8.17/Python/errors.c:221:29
#52 0x00000000004d84b9 _Py_CheckFunctionResult /usr/local/src/conda/python-3.8.17/Objects/call.c:25:25
#53 0x00000000004d84b9 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.17/Include/cpython/abstract.h:128:12
#54 0x00000000004d84b9 call_function /usr/local/src/conda/python-3.8.17/Python/ceval.c:4963:13
#55 0x00000000004d84b9 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.17/Python/ceval.c:3500:19
#56 0x00000000004d70e1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.17/Python/ceval.c:4308:9
#57 0x0000000000585e99 PyEval_EvalCodeEx /usr/local/src/conda/python-3.8.17/Python/ceval.c:4334:1
#58 0x0000000000585e5b PyEval_EvalCode /usr/local/src/conda/python-3.8.17/Python/ceval.c:724:1
#59 0x00000000005a5c21 run_eval_code_obj /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:1166:9
#60 0x00000000005a4c2f _Py_DECREF /usr/local/src/conda/python-3.8.17/Include/object.h:470:8
#61 0x00000000005a4c2f run_mod /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:1189:5
#62 0x000000000045c580 pyrun_file /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:1085:15
#63 0x000000000045c121 pyrun_simple_file /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:439:13
#64 0x000000000045c121 PyRun_SimpleFileExFlags /usr/local/src/conda/python-3.8.17/Python/pythonrun.c:472:15
#65 0x000000000044fe93 _Py_DECREF /usr/local/src/conda/python-3.8.17/Include/object.h:470:8
#66 0x000000000044fe93 _Py_XDECREF /usr/local/src/conda/python-3.8.17/Include/object.h:541:9
#67 0x000000000044fe93 pymain_run_file /usr/local/src/conda/python-3.8.17/Modules/main.c:392:5
#68 0x000000000044fe93 pymain_run_python /usr/local/src/conda/python-3.8.17/Modules/main.c:616:21
#69 0x000000000044fe93 Py_RunMain.cold /usr/local/src/conda/python-3.8.17/Modules/main.c:695:5
#70 0x0000000000579ef9 Py_BytesMain /usr/local/src/conda/python-3.8.17/Modules/main.c:1128:1
#71 0x00007f1f6de50555 __libc_start_main (/lib64/libc.so.6+0x22555)
#72 0x0000000000579dad _start (/home/md2249/miniconda3/envs/allo/bin/python3.8+0x579dad)
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Aborted

Expected behavior
Should generate HLS code in the form given in the Allo tutorial.

[BUG] Failing to partition function output

Describe the bug
When the user calls the same subfunction several times in the top-level function and wants to partition its return values in the top-level function region, MLIR compilation fails and generates several redundant affine_maps.

To Reproduce
Example1:

import allo
from allo.ir.types import int32

def test_call_partition_2():
    M, N = 2, 2
    def matrix_addi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] + 1
        return B

    def top(inp: int32[M, N]) -> int32[M, N]:
        outp: int32[M, N]
        temp1 = matrix_addi(inp)
        temp2 = matrix_addi(inp)
        for i, j in allo.grid(M, N):
            outp[i, j] = temp1[i, j] + temp2[i, j] 
        return outp
    s = allo.customize(top)
    s.partition(top.temp1)
    s.partition(top.temp2)
    print(s.module)

test_call_partition_2()

Example2:

import allo
from allo.ir.types import int32

def test_call_partition_4():
    M, N = 2, 2
    def matrix_addi(A: int32[M, N]) -> int32[M, N]:
        B: int32[M, N]
        for i, j in allo.grid(M, N):
            B[i, j] = A[i, j] + 1
        return B

    def matrix_add(A: int32[M, N], B: int32[M, N]) -> int32[M, N]:
        C: int32[M, N]
        for i, j in allo.grid(M, N):
            C[i, j] = A[i, j] + B[i, j]
        return B

    def top(inp: int32[M, N]) -> int32[M, N]:
        outp: int32[M, N]
        temp1 = matrix_addi(inp)
        temp2 = matrix_addi(inp)
        outp = matrix_add(temp1, temp2)
        return outp
    s = allo.customize(top)
    s.partition(top.temp1)
    print(s.module)

    f = s.build(target="vhls")
    print(f)

test_call_partition_4()

Buggy output
Output of example1:

Error: failed to run MLIR lower pipeline, printing module...
#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (2)>
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32, #map>, sym_name = "matrix_addi"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32, #map>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32>, index, index) -> i32
        %2 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %3 = "arith.addi"(%1, %2) : (i32, i32) -> i32
        "affine.store"(%3, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32, #map>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32, #map>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32>, sym_name = "top"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "outp"} : () -> memref<2x2xi32>
    %1 = "func.call"(%arg0) <{callee = @matrix_addi}> : (memref<2x2xi32>) -> memref<2x2xi32, #map>
    %2 = "func.call"(%arg0) <{callee = @matrix_addi}> : (memref<2x2xi32>) -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %3 = "affine.load"(%1, %arg1, %arg2) {from = "temp1", map = #map1} : (memref<2x2xi32, #map>, index, index) -> i32
        %4 = "affine.load"(%2, %arg1, %arg2) {from = "temp2", map = #map1} : (memref<2x2xi32>, index, index) -> i32
        %5 = "arith.addi"(%3, %4) : (i32, i32) -> i32
        "affine.store"(%5, %0, %arg1, %arg2) {map = #map1, to = "outp"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32>) -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "test_testing.py", line 248, in <module>
    test_call_partition_2()
  File "test_testing.py", line 179, in test_call_partition_2
    s.partition(top.temp1)
  File "/home/jz2292/allo/allo/customize.py", line 73, in wrapper
    _mlir_lower_pipeline(args[0].module)
  File "/home/jz2292/allo/allo/build_module.py", line 24, in _mlir_lower_pipeline
    raise e
  File "/home/jz2292/allo/allo/build_module.py", line 19, in _mlir_lower_pipeline
    mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: unknown: 'func.call' op result type mismatch at index 0
 note: unknown: see current operation: %2 = "func.call"(%arg0) <{callee = @matrix_addi}> : (memref<2x2xi32>) -> memref<2x2xi32>
 note: unknown:       op result types: 'memref<2x2xi32>'
 note: unknown: function result types: 'memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>'

Output of example2:

Error: failed to run MLIR lower pipeline, printing module...
#map = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (2)>
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32, #map>, sym_name = "matrix_addi"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "B"} : () -> memref<2x2xi32, #map>
    "affine.for"() ({
    ^bb0(%arg1: index):
      "affine.for"() ({
      ^bb0(%arg2: index):
        %1 = "affine.load"(%arg0, %arg1, %arg2) {from = "A", map = #map1} : (memref<2x2xi32>, index, index) -> i32
        %2 = "arith.constant"() <{value = 1 : i32}> : () -> i32
        %3 = "arith.addi"(%1, %2) : (i32, i32) -> i32
        "affine.store"(%3, %0, %arg1, %arg2) {map = #map1, to = "B"} : (i32, memref<2x2xi32, #map>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%0) : (memref<2x2xi32, #map>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32, #map>, memref<2x2xi32>) -> memref<2x2xi32>, sym_name = "matrix_add"}> ({
  ^bb0(%arg0: memref<2x2xi32, #map>, %arg1: memref<2x2xi32>):
    %0 = "memref.alloc"() <{odsOperandSegmentSizes = array<i32: 0, 0>}> {name = "C"} : () -> memref<2x2xi32>
    "affine.for"() ({
    ^bb0(%arg2: index):
      "affine.for"() ({
      ^bb0(%arg3: index):
        %1 = "affine.load"(%arg0, %arg2, %arg3) {from = "A", map = #map1} : (memref<2x2xi32, #map>, index, index) -> i32
        %2 = "affine.load"(%arg1, %arg2, %arg3) {from = "B", map = #map1} : (memref<2x2xi32>, index, index) -> i32
        %3 = "arith.addi"(%1, %2) : (i32, i32) -> i32
        "affine.store"(%3, %0, %arg2, %arg3) {map = #map1, to = "C"} : (i32, memref<2x2xi32>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
      "affine.yield"() : () -> ()
    }) {loop_name = "i", lower_bound = #map2, op_name = "S_i_j_0", step = 1 : i32, upper_bound = #map3} : () -> ()
    "func.return"(%arg1) : (memref<2x2xi32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<2x2xi32>) -> memref<2x2xi32>, sym_name = "top"}> ({
  ^bb0(%arg0: memref<2x2xi32>):
    %0 = "func.call"(%arg0) <{callee = @matrix_addi}> : (memref<2x2xi32>) -> memref<2x2xi32, #map>
    %1 = "func.call"(%arg0) <{callee = @matrix_addi}> : (memref<2x2xi32>) -> memref<2x2xi32>
    %2 = "func.call"(%0, %1) <{callee = @matrix_add}> : (memref<2x2xi32, #map>, memref<2x2xi32>) -> memref<2x2xi32>
    "func.return"(%2) : (memref<2x2xi32>) -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "test_testing.py", line 250, in <module>
    test_call_partition_4()
  File "test_testing.py", line 233, in test_call_partition_4
    s.partition(top.temp1)
  File "/home/jz2292/allo/allo/customize.py", line 73, in wrapper
    _mlir_lower_pipeline(args[0].module)
  File "/home/jz2292/allo/allo/build_module.py", line 24, in _mlir_lower_pipeline
    raise e
  File "/home/jz2292/allo/allo/build_module.py", line 19, in _mlir_lower_pipeline
    mlir_pass_manager.parse(pipeline).run(module.operation)
hcl_mlir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: unknown: 'func.call' op result type mismatch at index 0
 note: unknown: see current operation: %1 = "func.call"(%arg0) <{callee = @matrix_addi}> : (memref<2x2xi32>) -> memref<2x2xi32>
 note: unknown:       op result types: 'memref<2x2xi32>'
 note: unknown: function result types: 'memref<2x2xi32, affine_map<(d0, d1) -> (d0, d1, 0, 0)>>'

Expected behavior
The builder should detect that the partition operations target the same object, e.g., matrix_addi.B in these two examples.

Additional context
When the subfunction is called only once in the top-level function, partitioning its return value in the top-level function region succeeds.
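
For reference, a minimal sketch of that single-call case (untested; it reuses matrix_addi, M, N, and the imports from the examples above, and the top_single name is hypothetical):

def top_single(inp: int32[M, N]) -> int32[M, N]:
    outp: int32[M, N]
    temp1 = matrix_addi(inp)  # matrix_addi has exactly one call site here
    for i, j in allo.grid(M, N):
        outp[i, j] = temp1[i, j]
    return outp

s = allo.customize(top_single)
s.partition(top_single.temp1)  # reported to succeed with a single call site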

[BUG] Builder doesn't build scf.for loop when both loop bounds are expressions

Describe the bug
The program crashes when both bounds of a for loop are expressions.

To Reproduce

import allo

def test():
    for i in range(10):
        for j in range(i, i + 1):
            # Implementation
            pass  # placeholder so the snippet parses; body elided in the report

s = allo.customize(test)

Buggy output

python: /work/shared/users/common/llvm-project-18.x/llvm/include/llvm/ADT/STLExtras.h:1320: DerivedT llvm::detail::indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>::drop_front(size_t) const [with DerivedT = mlir::OperandRange; BaseT = mlir::OpOperand*; T = mlir::Value; PointerT = mlir::Value; ReferenceT = mlir::Value; size_t = long unsigned int]: Assertion `size() >= n && "Dropping more elements than exist"' failed.
 #0 0x00007fdc7f188bc8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x185bbc8)
 #1 0x00007fdc7f18644c SignalHandler(int) Signals.cpp:0:0
 #2 0x00007fdd2e826630 __restore_rt sigaction.c:0:0
 #3 0x00007fdd2dd76387 raise (/lib64/libc.so.6+0x36387)
 #4 0x00007fdd2dd77a78 abort (/lib64/libc.so.6+0x37a78)
 #5 0x00007fdd2dd6f1a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)
 #6 0x00007fdd2dd6f252 (/lib64/libc.so.6+0x2f252)
 #7 0x00007fdc802a80a3 (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x297b0a3)
 #8 0x00007fdc802c9ff7 mlir::detail::RegionBranchOpInterfaceInterfaceTraits::Model<mlir::affine::AffineForOp>::getSuccessorEntryOperands(mlir::detail::RegionBranchOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, std::optional<unsigned int>) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x299cff7)
 #9 0x00007fdc83d25007 std::optional<mlir::TypeRange> llvm::function_ref<std::optional<mlir::TypeRange> (std::optional<unsigned int>)>::callback_fn<mlir::detail::verifyTypesAlongControlFlowEdges(mlir::Operation*)::'lambda'(std::optional<unsigned int>)>(long, std::optional<unsigned int>) ControlFlowInterfaces.cpp:0:0
#10 0x00007fdc83d27df3 verifyTypesAlongAllEdges(mlir::Operation*, std::optional<unsigned int>, llvm::function_ref<std::optional<mlir::TypeRange> (std::optional<unsigned int>)>) ControlFlowInterfaces.cpp:0:0
#11 0x00007fdc83d28f6d mlir::detail::verifyTypesAlongControlFlowEdges(mlir::Operation*) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x63fbf6d)
#12 0x00007fdc802ed278 mlir::Op<mlir::affine::AffineForOp, mlir::OpTrait::OneRegion, mlir::OpTrait::VariadicResults, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::VariadicOperands, mlir::OpTrait::SingleBlockImplicitTerminator<mlir::affine::AffineYieldOp>::Impl, mlir::OpTrait::OpInvariants, mlir::OpTrait::AutomaticAllocationScope, mlir::ConditionallySpeculatable::Trait, mlir::OpTrait::HasRecursiveMemoryEffects, mlir::LoopLikeOpInterface::Trait, mlir::RegionBranchOpInterface::Trait>::verifyRegionInvariants(mlir::Operation*) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x29c0278)
#13 0x00007fdc7f1dd927 llvm::unique_function<mlir::LogicalResult (mlir::Operation*) const>::operator()(mlir::Operation*) const /work/shared/users/common/llvm-project-18.x/llvm/include/llvm/ADT/FunctionExtras.h:409:3
#14 0x00007fdc802e5dad mlir::RegisteredOperationName::Model<mlir::affine::AffineForOp>::verifyRegionInvariants(mlir::Operation*) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x29b8dad)
#15 0x00007fdc7f08232e (anonymous namespace)::OperationVerifier::verifyOnExit(mlir::Operation&) Verifier.cpp:0:0
#16 0x00007fdc7f0808b9 (anonymous namespace)::OperationVerifier::verifyOperation(mlir::Operation&) Verifier.cpp:0:0
#17 0x00007fdc7f08264b (anonymous namespace)::OperationVerifier::verifyOnExit(mlir::Operation&) Verifier.cpp:0:0
#18 0x00007fdc7f0808b9 (anonymous namespace)::OperationVerifier::verifyOperation(mlir::Operation&) Verifier.cpp:0:0
#19 0x00007fdc7f0832a6 mlir::verify(mlir::Operation*, bool) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x17562a6)
#20 0x00007fdc7ef75ad1 verifyOpAndAdjustFlags(mlir::Operation*, mlir::OpPrintingFlags) AsmPrinter.cpp:0:0
#21 0x00007fdc7ef8979d mlir::AsmState::AsmState(mlir::Operation*, mlir::OpPrintingFlags const&, llvm::DenseMap<mlir::Operation*, std::pair<unsigned int, unsigned int>, llvm::DenseMapInfo<mlir::Operation*, void>, llvm::detail::DenseMapPair<mlir::Operation*, std::pair<unsigned int, unsigned int>>>*, mlir::FallbackAsmResourceMap*) (.constprop.0) AsmPrinter.cpp:0:0
#22 0x00007fdc7ef9291e mlir::Operation::print(llvm::raw_ostream&, mlir::OpPrintingFlags const&) (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x166591e)
#23 0x00007fdc7eea71a9 mlirOperationPrintWithFlags (/work/shared/users/common/hcl-dialect-18.x/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x157a1a9)
#24 0x00007fdc862879d4 mlir::python::PyOperationBase::print(pybind11::object, bool, std::optional<long>, bool, bool, bool, bool, bool) /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/IRCore.cpp:1166:29
#25 0x00007fdc86287f00 mlir::python::PyOperationBase::getAsm(bool, std::optional<long>, bool, bool, bool, bool, bool) /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/IRCore.cpp:1201:8
#26 0x00007fdc86291e17 mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&)::operator()(mlir::python::PyOperationBase&) const /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/IRCore.cpp:2821:56
#27 0x00007fdc862dc1a4 pybind11::object pybind11::detail::argument_loader<mlir::python::PyOperationBase&>::call_impl<pybind11::object, mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&)&, 0ul, pybind11::detail::void_type>(mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&)&, std::integer_sequence<unsigned long, 0ul>, pybind11::detail::void_type&&) && /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1481:5
#28 0x00007fdc862d2bec _ZNO8pybind116detail15argument_loaderIJRN4mlir6python15PyOperationBaseEEE4callINS_6objectENS0_9void_typeERZNS3_14populateIRCoreERNS_7module_EEUlS5_E49_EENSt9enable_ifIXntsrSt7is_voidIT_E5valueESG_E4typeEOT1_ /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1450:5
#29 0x00007fdc862c10bc void pybind11::cpp_function::initialize<mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&), pybind11::object, mlir::python::PyOperationBase&, pybind11::name, pybind11::is_method, pybind11::sibling, char [44]>(mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&)&&, pybind11::object (*)(mlir::python::PyOperationBase&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, char const (&) [44])::'lambda1'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:253:40
#30 0x00007fdc862c113d void pybind11::cpp_function::initialize<mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&), pybind11::object, mlir::python::PyOperationBase&, pybind11::name, pybind11::is_method, pybind11::sibling, char [44]>(mlir::python::populateIRCore(pybind11::module_&)::'lambda49'(mlir::python::PyOperationBase&)&&, pybind11::object (*)(mlir::python::PyOperationBase&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, char const (&) [44])::'lambda1'(pybind11::detail::function_call&)::_FUN(pybind11::detail::function_call&) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:224:21
#31 0x00007fdc8614bd8e pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:946:35
#32 0x00000000004f5572 cfunction_call_varargs /usr/local/src/conda/python-3.8.18/Objects/call.c:745:9
#33 0x00000000004f5572 PyCFunction_Call /usr/local/src/conda/python-3.8.18/Objects/call.c:773:16
#34 0x00000000004e0e1b _PyObject_MakeTpCall /usr/local/src/conda/python-3.8.18/Objects/call.c:159:18
#35 0x00000000004f542a _PyObject_Vectorcall /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:125:16
#36 0x00000000004f542a _PyObject_Vectorcall /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:115:1
#37 0x00000000004f542a method_vectorcall /usr/local/src/conda/python-3.8.18/Objects/classobject.c:67:20
#38 0x00000000004f75ca PyErr_Occurred /usr/local/src/conda/python-3.8.18/Python/errors.c:221:29
#39 0x00000000004f75ca _Py_CheckFunctionResult /usr/local/src/conda/python-3.8.18/Objects/call.c:25:25
#40 0x00000000004f75ca PyVectorcall_Call /usr/local/src/conda/python-3.8.18/Objects/call.c:210:12
#41 0x00000000004f75ca PyObject_Call /usr/local/src/conda/python-3.8.18/Objects/call.c:228:16
#42 0x00007fdc8615a9b1 pybind11::detail::simple_collector<(pybind11::return_value_policy)1>::call(_object*) const /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1502:47
#43 0x00007fdc8626bfe3 pybind11::object pybind11::detail::object_api<pybind11::detail::accessor<pybind11::detail::accessor_policies::str_attr>>::operator()<(pybind11::return_value_policy)1>() const /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1672:95
#44 0x00007fdc862916e6 mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object)::operator()(pybind11::object) const /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/IRCore.cpp:2740:59
#45 0x00007fdc862dba65 pybind11::object pybind11::detail::argument_loader<pybind11::object>::call_impl<pybind11::object, mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object)&, 0ul, pybind11::detail::void_type>(mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object)&, std::integer_sequence<unsigned long, 0ul>, pybind11::detail::void_type&&) && /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1480:91
#46 0x00007fdc862d2728 _ZNO8pybind116detail15argument_loaderIJNS_6objectEEE4callIS2_NS0_9void_typeERZN4mlir6python14populateIRCoreERNS_7module_EEUlS2_E38_EENSt9enable_ifIXntsrSt7is_voidIT_E5valueESE_E4typeEOT1_ /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1450:5
#47 0x00007fdc862befb6 void pybind11::cpp_function::initialize<mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object), pybind11::object, pybind11::object, pybind11::name, pybind11::is_method, pybind11::sibling, char [243]>(mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object)&&, pybind11::object (*)(pybind11::object), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, char const (&) [243])::'lambda1'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:253:40
#48 0x00007fdc862bf057 void pybind11::cpp_function::initialize<mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object), pybind11::object, pybind11::object, pybind11::name, pybind11::is_method, pybind11::sibling, char [243]>(mlir::python::populateIRCore(pybind11::module_&)::'lambda38'(pybind11::object)&&, pybind11::object (*)(pybind11::object), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, char const (&) [243])::'lambda1'(pybind11::detail::function_call&)::_FUN(pybind11::detail::function_call&) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:224:21
#49 0x00007fdc8614bd8e pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /work/shared/users/phd/hc676/envs/allo/lib/python3.8/site-packages/pybind11/include/pybind11/pybind11.h:946:35
#50 0x00000000004f5572 cfunction_call_varargs /usr/local/src/conda/python-3.8.18/Objects/call.c:745:9
#51 0x00000000004f5572 PyCFunction_Call /usr/local/src/conda/python-3.8.18/Objects/call.c:773:16
#52 0x00000000004e0e1b _PyObject_MakeTpCall /usr/local/src/conda/python-3.8.18/Objects/call.c:159:18
#53 0x00000000004f542a _PyObject_Vectorcall /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:125:16
#54 0x00000000004f542a _PyObject_Vectorcall /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:115:1
#55 0x00000000004f542a method_vectorcall /usr/local/src/conda/python-3.8.18/Objects/classobject.c:67:20
#56 0x00000000004e4ea3 _PyObject_Vectorcall.lto_priv.6 /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:127:11
#57 0x000000000053a637 _PyObject_FastCall /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:147:12
#58 0x000000000053a637 call_unbound /usr/local/src/conda/python-3.8.18/Objects/typeobject.c:1456:16
#59 0x000000000053a637 call_unbound /usr/local/src/conda/python-3.8.18/Objects/typeobject.c:1449:1
#60 0x000000000053a637 call_method /usr/local/src/conda/python-3.8.18/Objects/typeobject.c:1485:14
#61 0x000000000050cc3f PyObject_Str /usr/local/src/conda/python-3.8.18/Objects/object.c:593:12
#62 0x000000000050cc3f PyObject_Str /usr/local/src/conda/python-3.8.18/Objects/object.c:558:1
#63 0x00000000005ae81b PyFile_WriteObject /usr/local/src/conda/python-3.8.18/Objects/fileobject.c:131:17
#64 0x00000000005ae6fb builtin_print /usr/local/src/conda/python-3.8.18/Python/bltinmodule.c:1868:12
#65 0x00000000004dfad2 cfunction_vectorcall_FASTCALL_KEYWORDS /usr/local/src/conda/python-3.8.18/Objects/methodobject.c:442:5
#66 0x00000000004d84a9 PyErr_Occurred /usr/local/src/conda/python-3.8.18/Python/errors.c:221:29
#67 0x00000000004d84a9 _Py_CheckFunctionResult /usr/local/src/conda/python-3.8.18/Objects/call.c:25:25
#68 0x00000000004d84a9 _PyObject_Vectorcall /usr/local/src/conda/python-3.8.18/Include/cpython/abstract.h:128:12
#69 0x00000000004d84a9 call_function /usr/local/src/conda/python-3.8.18/Python/ceval.c:4963:13
#70 0x00000000004d84a9 _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.8.18/Python/ceval.c:3500:19
#71 0x00000000004d70d1 _PyEval_EvalCodeWithName /usr/local/src/conda/python-3.8.18/Python/ceval.c:4308:9
#72 0x0000000000585e29 PyEval_EvalCodeEx /usr/local/src/conda/python-3.8.18/Python/ceval.c:4334:1
#73 0x0000000000585deb PyEval_EvalCode /usr/local/src/conda/python-3.8.18/Python/ceval.c:724:1
#74 0x00000000005a5bd1 run_eval_code_obj /usr/local/src/conda/python-3.8.18/Python/pythonrun.c:1166:9
#75 0x00000000005a4bdf _Py_DECREF /usr/local/src/conda/python-3.8.18/Include/object.h:470:8
#76 0x00000000005a4bdf run_mod /usr/local/src/conda/python-3.8.18/Python/pythonrun.c:1189:5
#77 0x000000000045c538 pyrun_file /usr/local/src/conda/python-3.8.18/Python/pythonrun.c:1085:15
#78 0x000000000045c0d9 pyrun_simple_file /usr/local/src/conda/python-3.8.18/Python/pythonrun.c:439:13
#79 0x000000000045c0d9 PyRun_SimpleFileExFlags /usr/local/src/conda/python-3.8.18/Python/pythonrun.c:472:15
#80 0x000000000044fe8f _Py_DECREF /usr/local/src/conda/python-3.8.18/Include/object.h:470:8
#81 0x000000000044fe8f _Py_XDECREF /usr/local/src/conda/python-3.8.18/Include/object.h:541:9
#82 0x000000000044fe8f pymain_run_file /usr/local/src/conda/python-3.8.18/Modules/main.c:392:5
#83 0x000000000044fe8f pymain_run_python /usr/local/src/conda/python-3.8.18/Modules/main.c:616:21
#84 0x000000000044fe8f Py_RunMain.cold /usr/local/src/conda/python-3.8.18/Modules/main.c:695:5
#85 0x0000000000579e89 Py_BytesMain /usr/local/src/conda/python-3.8.18/Modules/main.c:1128:1
#86 0x00007fdd2dd62555 __libc_start_main (/lib64/libc.so.6+0x22555)
#87 0x0000000000579d3d _start (/home/rl569/miniconda3/envs/allo/bin/python3.8+0x579d3d)
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Aborted (core dumped)

Expected behavior
A schedule that we can work with is created for the kernel.
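
For contrast, the title suggests a loop with only one expression-valued bound builds correctly; a minimal, untested sketch of that case:

import allo

def single_expr_bound():
    for i in range(10):
        for j in range(i + 1):  # only the upper bound is an expression of i
            pass  # body elided, as in the report above

s = allo.customize(single_expr_bound)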

[BUG] Parameters cannot be used in index calculations

Whether I define the stride parameter inside or outside the avgpool_nchw function, the index computation A[n, c, h * stride + rh, w * stride + rw] cannot be performed. I have tried declaring stride as both int32 and index, to no avail. However, when I replace the stride variable with the literal 1, the program runs. It seems that parameters cannot be used in index calculations.

The following program reproduces the issue.

import allo
from allo.ir.types import float32, index

def test_avgpool_nchw():
    bs = 4
    ic, oc = 16, 16
    ih, iw = 8, 8
    kh, kw = 2, 2
    stride = 1
    oh, ow = (ih - kh) // stride + 1, (iw - kw) // stride + 1
    dtype = float32

    def avgpool_nchw(A: float32[bs, ic, ih, iw]) -> float32[bs, oc, oh, ow]:
        B: float32[bs, oc, oh, ow] = 0.0
        stride: index = 1
        for n, c, h, w in allo.grid(bs, oc, oh, ow):
            v: float32 = 0.0
            for rh, rw in allo.reduction(kh, kw):
                v += A[n, c, h * stride + rh, w * stride + rw]  # fails to build; see the traceback below
            B[n, c, h, w] = v / (kh * kw)
        return B

    s = allo.customize(avgpool_nchw)

test_avgpool_nchw()

The following output indicates an error in multiplication.

Traceback (most recent call last):
  File "../working/test_schedule_memory.py", line 549, in <module>
    test_avgpool_nchw()
  File "../working/test_schedule_memory.py", line 507, in test_avgpool_nchw
    s = allo.customize(avgpool_nchw)
  File "/work/shared/users/ugrad/hw783/allo/allo/customize.py", line 412, in customize
    ASTTransformer()(ctx, tree)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 70, in __call__
    return method(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 840, in build_Module
    build_stmt(ctx, stmt)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 70, in __call__
    return method(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 706, in build_FunctionDef
    stmts = build_stmts(ctx, node.body)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 926, in build_stmts
    results.append(build_stmt(ctx, stmt))
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 70, in __call__
    return method(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 248, in build_For
    return ASTTransformer.build_grid_for(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 226, in build_grid_for
    build_stmts(ctx, node.body)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 926, in build_stmts
    results.append(build_stmt(ctx, stmt))
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 70, in __call__
    return method(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 248, in build_For
    return ASTTransformer.build_grid_for(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 226, in build_grid_for
    build_stmts(ctx, node.body)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 926, in build_stmts
    results.append(build_stmt(ctx, stmt))
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 70, in __call__
    return method(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 470, in build_AugAssign
    rhs = build_stmt(ctx, node.value)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 70, in __call__
    return method(ctx, node)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 535, in build_Subscript
    expr = ASTTransformer.build_affine_expr(ctx, index)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 503, in build_affine_expr
    lhs = ASTTransformer.build_affine_expr(ctx, node.left)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 519, in build_affine_expr
    return op(lhs, rhs)
  File "/work/shared/users/ugrad/hw783/allo/allo/ir/builder.py", line 508, in <lambda>
    ast.Mult: lambda l, r: l * r,
TypeError: __mul__(): incompatible function arguments. The following argument types are supported:
    1. (self: hcl_mlir._mlir_libs._mlir.ir.AffineExpr, arg0: hcl_mlir._mlir_libs._mlir.ir.AffineExpr) -> (anonymous namespace)::PyAffineMulExpr
    2. (self: hcl_mlir._mlir_libs._mlir.ir.AffineExpr, arg0: int) -> (anonymous namespace)::PyAffineMulExpr

Invoked with: AffineExpr(d2), None
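
As observed above, the program runs once the stride variable is replaced by the literal 1. A hedged workaround sketch (hypothetical avgpool_nchw_const name, reusing the constants and imports from the repro above) inlines the constant so the subscript stays an affine expression of the loop indices:

    def avgpool_nchw_const(A: float32[bs, ic, ih, iw]) -> float32[bs, oc, oh, ow]:
        B: float32[bs, oc, oh, ow] = 0.0
        for n, c, h, w in allo.grid(bs, oc, oh, ow):
            v: float32 = 0.0
            for rh, rw in allo.reduction(kh, kw):
                v += A[n, c, h * 1 + rh, w * 1 + rw]  # stride inlined as a literal
            B[n, c, h, w] = v / (kh * kw)
        return B

    s = allo.customize(avgpool_nchw_const)  # reported to run with a literal stride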

[BUG] For loop does not run when given int param from function

Describe the bug
The for loop does not execute when its range bound is an int argument of the function.

To Reproduce

import allo
from allo.ir.types import int32

SIZE = 5

def param_for_loop(end: int32) -> int32[SIZE]:
    arr: int32[SIZE] = 0
    for i in range(end):  # dynamic bound from the argument; the loop never runs
        if i < SIZE:
            arr[i] = 1

    return arr


def const_for_loop() -> int32[SIZE]:
    arr: int32[SIZE] = 0
    for i in range(SIZE):
        arr[i] = 1

    return arr

s1 = allo.customize(param_for_loop)
mod1 = s1.build(target="llvm")

s2 = allo.customize(const_for_loop)
mod2 = s2.build(target="llvm")

print(mod1(5))
print(mod2())

Buggy output

[0 0 0 0 0]
[1 1 1 1 1]

Expected behavior
If mod1 takes an int >= 5, both arrays should be [1 1 1 1 1]; for smaller inputs, the first array should contain that many 1s followed by 0s.
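
A possible workaround sketch, untested: keep the static bound SIZE for the loop and use the runtime argument only as a guard (hypothetical param_for_loop_guarded name; reuses the imports and SIZE from the repro above):

def param_for_loop_guarded(end: int32) -> int32[SIZE]:
    arr: int32[SIZE] = 0
    for i in range(SIZE):  # static bound, which builds correctly
        if i < end:        # runtime argument used only as a guard
            arr[i] = 1
    return arr

s3 = allo.customize(param_for_loop_guarded)
mod3 = s3.build(target="llvm")
print(mod3(5))  # expected: [1 1 1 1 1]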
