diff --git a/.gitignore b/.gitignore
index 894a44cc06..6c7fef03eb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,3 +102,60 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+# Windows
+Thumbs.db
+
+# Ignore files built by Visual Studio / Visual Studio Code
+*.obj
+*.exe
+*.pdb
+*.user
+*.dot
+*.jpg
+*.aps
+*.pch
+*.vspscc
+*_i.c
+*_p.c
+*.ncb
+*.suo
+*.tlb
+*.tlh
+*.bak
+*.cache
+*.ilk
+[Bb]in
+[Dd]ebug*/
+*.lib
+*.sbr
+obj/
+[Rr]elease*/
+_ReSharper*/
+[Tt]est[Rr]esult*
+.vs/
+.vscode/
+src.VC.db
+src.VC.VC.opendb
+*.exp
+
+# DaCe
+.dacecache/
+out.sdfg
+*.dot
+*.out
+results.log
+perf.json
+perf*.csv
+/dace/frontend/octave/parsetab.py
+
+# Xilinx
+xilinx_vcu1525_*
+sdaccel_profile_*
+sdaccel_timeline_*
+
+# NVIDIA
+*.nvprof
+
+# Miscellaneous
+*~
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000..bebc84450f
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,10 @@
+[submodule "dace/external/cub"]
+    path = dace/external/cub
+    url = https://github.com/NVlabs/cub.git
+    branch = 1.8.0
+[submodule "dace/external/moodycamel"]
+    path = dace/external/moodycamel
+    url = https://github.com/cameron314/concurrentqueue.git
+[submodule "dace/external/hlslib"]
+    path = dace/external/hlslib
+    url = https://github.com/definelicht/hlslib.git
diff --git a/LICENSE b/LICENSE
index 30b39653b3..a1997075fc 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 BSD 3-Clause License
 
-Copyright (c) 2019, SPCL
+Copyright (c) 2019, Scalable Parallel Computing Lab, ETH Zurich
 All rights reserved.
 
 Redistribution and use in source and binary forms, with or without
diff --git a/README.md b/README.md
index ccc37d3460..b1947d00b6 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,90 @@
-# dace
-DaCe - Data Centric Parallel Programming
+![D](dace.svg)aCe - Data-Centric Parallel Programming
+=====================================================
+
+_Decoupling domain science from performance optimization._
+
+DaCe compiles code in various programming languages and paradigms (Python/NumPy, MATLAB, TensorFlow) and maps it efficiently to **CPUs, GPUs, and FPGAs** with high utilization, on par with the state of the art. The key feature driving DaCe is its Stateful DataFlow multiGraph (SDFG) *data-centric intermediate representation*: a transformable, interactive representation of code based on data movement.
+With data-centric parallel programming, we enable **direct knowledge transfer** of performance optimization, regardless of the scientific application or the target processor.
+
+DaCe programs can be written inline in Python and transformed on the command line, or SDFGs can be interactively modified using the Data-centric Interactive Optimization Development Environment (DIODE).
+
+For more information, see our [paper](http://www.arxiv.org/abs/1902.10345).
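+
+As a small taste, a DaCe program can be written directly in Python (a minimal
+sketch in the explicit-dataflow style; see the Explicit Dataflow tutorial
+below for the authoritative syntax of this version):
+
+```python
+import numpy as np
+import dace
+
+N = dace.symbol('N')  # symbolic size, bound with N.set() before the call
+
+@dace.program(dace.float64[N], dace.float64[N])
+def axpy(A, B):
+    @dace.map(_[0:N])           # parallel map over i in [0, N)
+    def scale(i):
+        a << A[i]               # memlet: read A[i]
+        b << B[i]               # memlet: read B[i]
+        out >> B[i]             # memlet: write B[i]
+        out = 2.0 * a + b
+
+A = np.random.rand(1024)
+B = np.random.rand(1024)
+N.set(1024)
+axpy(A, B)  # parses to an SDFG, compiles it, and runs the binary
+```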
+
+Tutorials
+---------
+
+* _Implicit Dataflow in Python (coming soon)_
+* [Explicit Dataflow in Python](tutorials/explicit.ipynb)
+* [SDFG API](tutorials/sdfg_api.ipynb)
+* [Transformations](tutorials/transformations.ipynb)
+
+Installation and Dependencies
+-----------------------------
+
+To install: `pip install dace`
+
+Runtime dependencies:
+ * A C++14-capable compiler (e.g., gcc 5.3+)
+ * Python 3.5 or newer
+
+Running DIODE may require additional dependencies:
+ * `sudo apt-get install libgtksourceviewmm-3.0-dev libyaml-dev`
+ * `sudo apt-get install python3-cairo python3-gi-cairo libgirepository1.0-dev xdot libwebkitgtk-dev libwebkitgtk-3.0-dev libwebkit2gtk-4.0-dev`
+ * `pip install pygobject matplotlib`
+
+To run DIODE on Windows, use MSYS2:
+ * Download from http://www.msys2.org/
+ * In the MSYS2 console, install all dependencies: `pacman -S mingw-w64-i686-gtk3 mingw-w64-i686-python2-gobject mingw-w64-i686-python3-gobject mingw-w64-i686-python3-cairo mingw-w64-i686-python3-pip mingw-w64-i686-gtksourceviewmm3 mingw-w64-i686-gcc mingw-w64-i686-boost mingw-w64-i686-python3-numpy mingw-w64-i686-python3-scipy mingw-w64-i686-python3-matplotlib`
+ * Update MSYS2: `pacman -Syu`, close and restart MSYS2, then run `pacman -Su` to update the rest of the packages.
+
+Publication
+-----------
+
+If you use DaCe, cite us:
+```bibtex
+@article{dace,
+  author    = {Ben-Nun, Tal and de Fine Licht, Johannes and Ziogas, Alexandros Nikolaos and Schneider, Timo and Hoefler, Torsten},
+  title     = {Stateful Dataflow Multigraphs: A Data-Centric Model for High-Performance Parallel Programs},
+  journal   = {CoRR},
+  volume    = {abs/1902.10345},
+  year      = {2019},
+  url       = {http://arxiv.org/abs/1902.10345},
+  archivePrefix = {arXiv},
+  eprint    = {1902.10345}
+}
+```
+
+Configuration
+-------------
+
+DaCe creates a file called `.dace.conf` in the user's home directory. It provides useful settings that can be modified either directly in the file (YAML), within DIODE, or overridden on a case-by-case basis using environment variables that begin with `DACE_` and specify the setting (where categories are separated by underscores). The full configuration schema is located [here](dace/config_schema.yml).
+
+Useful environment variable configurations include:
+
+* `DACE_CONFIG` (default: `~/.dace.conf`): Overrides the DaCe configuration file choice.
+
+Context configuration:
+ * `DACE_use_cache` (default: False): Uses the DaCe program cache instead of re-optimizing and compiling programs.
+ * `DACE_debugprint` (default: True): Prints debugging information.
+
+CPU target configuration:
+ * `DACE_compiler_cpu_executable` (default: g++): Chooses the default C++ compiler for CPU code.
+ * `DACE_compiler_cpu_additional_args` (default: None): Additional compiler flags (separated by spaces).
+
+SDFG processing:
+ * `DACE_optimizer_interface` (default: `dace.transformation.optimizer.SDFGOptimizer`): Controls the SDFG optimization process. If empty or the class name is invalid, the optimization process is skipped. By default, the transformation command-line interface is used.
+ * `DACE_optimizer_visualize` (default: False): Visualizes the optimization process by saving .dot (GraphViz) files after each pattern replacement.
+
+Profiling:
+ * `DACE_profiling` (default: False): Enables profiling measurement of the DaCe program runtime in milliseconds. Produces a log file and prints out the median runtime.
+ * `DACE_treps` (default: 100): Number of repetitions to run a DaCe program when profiling is enabled.
+
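+For example, to override settings for a single process from Python (a sketch;
+overrides must be in place before DaCe first reads its configuration):
+
+```python
+import os
+
+# DACE_<category>_<setting> mirrors the schema in dace/config_schema.yml
+os.environ['DACE_profiling'] = 'true'  # measure program runtime
+os.environ['DACE_treps'] = '50'        # fewer profiling repetitions
+
+import dace  # configuration (including overrides) is loaded from here on
+```
+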
+Contributing
+------------
+DaCe is an open-source project. We are happy to accept Pull Requests with your contributions!
+
+License
+-------
+DaCe is published under the New BSD license, see LICENSE.
+
diff --git a/dace.svg b/dace.svg
new file mode 100644
index 0000000000..c744ac88de
--- /dev/null
+++ b/dace.svg
@@ -0,0 +1,84 @@
+[84 lines of SVG markup for the DaCe "D" logo; the XML was lost in extraction, and only the `image/svg+xml` media-type string survives]
diff --git a/dace/__init__.py b/dace/__init__.py
new file mode 100644
index 0000000000..3e0aa6f8dc
--- /dev/null
+++ b/dace/__init__.py
@@ -0,0 +1,14 @@
+from .types import *
+
+# Python frontend
+from .frontend.python.decorators import *
+from .frontend.python.ndarray import *
+from .frontend.python.ndloop import ndrange
+from .frontend.python.simulator import simulate
+
+from .config import Config
+from .frontend.operations import *
+from .sdfg import compile, SDFG, SDFGState
+from .memlet import Memlet, EmptyMemlet
+from .graph.edges import InterstateEdge
+from .symbolic import symbol, eval
diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt
new file mode 100644
index 0000000000..916e42eae9
--- /dev/null
+++ b/dace/codegen/CMakeLists.txt
@@ -0,0 +1,315 @@
+cmake_minimum_required(VERSION 2.8.12)
+project(dace_program)
+
+# General options
+set(DACE_PROGRAM_NAME "dace_program" CACHE STRING "Name of DaCe program")
+set(DACE_FILES "" CACHE STRING "Host code files")
+set(DACE_LIBS "" CACHE STRING "Extra libraries")
+
+# Allow passing flags to various stages of Xilinx compilation process
+set(DACE_XILINX_MODE "simulation" CACHE STRING "Type of compilation/execution [simulation/software_emulation/hardware_emulation/hardware].")
+set(DACE_XILINX_HOST_FLAGS "" CACHE STRING "Extra flags to host code")
+set(DACE_XILINX_SYNTHESIS_FLAGS "" CACHE STRING "Extra flags for performing high-level synthesis")
+set(DACE_XILINX_BUILD_FLAGS "" CACHE STRING "Extra flags to xocc build phase")
+set(DACE_XILINX_TARGET_CLOCK 200 CACHE STRING "Target clock frequency of FPGA kernel")
+set(DACE_XILINX_PART_NAME "xcvu9p-fsgd2104-2l-e" CACHE STRING "Xilinx chip to target from HLS")
+set(DACE_XILINX_TARGET_PLATFORM "xilinx_vcu1525_dynamic_5_1" CACHE STRING "SDAccel platform to target")
+set(DACE_XILINX_ENABLE_DEBUGGING OFF CACHE STRING "Inject debugging cores to kernel build (always on for simulation/emulation)")
+set(HLSLIB_PART_NAME "${DACE_XILINX_PART_NAME}")
+
+# Target detection
+set(DACE_ENABLE_MPI OFF)
+set(DACE_ENABLE_CUDA OFF)
+set(DACE_ENABLE_XILINX OFF)
+
+# Split list by target
+foreach(DACE_FILE ${DACE_FILES})
+  # Extract the target from the folder name
+  get_filename_component(DACE_FILE_NAME ${DACE_FILE} NAME_WE)
+  get_filename_component(DACE_FILE_TARGET ${DACE_FILE} DIRECTORY)
+  get_filename_component(DACE_FILE_TARGET ${DACE_FILE_TARGET} NAME)
+  if(${DACE_FILE_TARGET} STREQUAL "cuda")
+    set(DACE_ENABLE_CUDA ON)
+    set(DACE_CUDA_FILES ${DACE_CUDA_FILES} ${DACE_FILE})
+  elseif(${DACE_FILE_TARGET} STREQUAL "xilinx")
+    set(DACE_ENABLE_XILINX ON)
+    if(DACE_FILE_NAME MATCHES ".+_host")
+      set(DACE_XILINX_HOST_FILES ${DACE_XILINX_HOST_FILES} ${DACE_FILE})
+    else()
+      set(DACE_XILINX_KERNEL_FILES ${DACE_XILINX_KERNEL_FILES} ${DACE_FILE})
+    endif()
+  elseif(${DACE_FILE_TARGET} STREQUAL "mpi")
+    set(DACE_ENABLE_MPI ON)
+    set(DACE_CPP_FILES ${DACE_CPP_FILES} ${DACE_FILE})
+  else()
+    set(DACE_CPP_FILES ${DACE_CPP_FILES} ${DACE_FILE})
+  endif()
+endforeach()
+
+# Internal dependencies
+set(DACE_RUNTIME_DIR ${CMAKE_SOURCE_DIR}/../runtime)
+include_directories(${DACE_RUNTIME_DIR}/include)
+
+# External dependencies
+find_package(Threads REQUIRED)
+find_package(OpenMP REQUIRED
COMPONENTS CXX) +file(TO_NATIVE_PATH "${CMAKE_BINARY_DIR}/" DACE_BINARY_DIR) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -DDACE_BINARY_DIR='\"${DACE_BINARY_DIR}\"'") +set(DACE_LIBS ${DACE_LIBS} ${CMAKE_THREAD_LIBS_INIT} ${OpenMP_CXX_LIBRARIES}) +if(DACE_ENABLE_MPI) + find_package(MPI REQUIRED) + include_directories(${MPI_CXX_INCLUDE_PATH}) + set(DACE_LIBS ${DACE_LIBS} ${MPI_CXX_LIBRARIES}) +endif() +if(DACE_ENABLE_CUDA) + find_package(CUDA REQUIRED) + set(CUDA_PROPAGATE_HOST_FLAGS OFF) + include_directories(${CUDA_INCLUDE_DIRS}) + set(DACE_LIBS ${DACE_LIBS} ${CUDA_LIBRARIES}) + add_definitions(-DWITH_CUDA) +endif() +if(DACE_ENABLE_XILINX) + set(DACE_HLSLIB_DIR ${CMAKE_SOURCE_DIR}/../external/hlslib) + set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${DACE_HLSLIB_DIR}/cmake) + find_package(SDAccel REQUIRED) + + include_directories(SYSTEM ${SDAccel_INCLUDE_DIRS} ${DACE_HLSLIB_DIR}/include) + add_definitions(-DDACE_XILINX) + set(DACE_LIBS ${DACE_LIBS} ${SDAccel_LIBRARIES}) + +endif() + +# Create CUDA object files +if(DACE_ENABLE_CUDA) + # Get local CUDA architectures + if (NOT DEFINED LOCAL_CUDA_ARCHITECTURES) + execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" + "${CMAKE_SOURCE_DIR}/tools/get_cuda_arch.cpp" + OUTPUT_VARIABLE _arch_out RESULT_VARIABLE _arch_res + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(_arch_res EQUAL 0) + string(REGEX REPLACE "\n" ";" _arch_out "${_arch_out}") + list(GET _arch_out -1 _local_arch) + string(REGEX REPLACE " " ";" _local_arch "${_local_arch}") + set(LOCAL_CUDA_ARCHITECTURES "${_local_arch}" CACHE STRING "Detected local GPUs for compilation") + message("-- Local CUDA architectures detected: ${LOCAL_CUDA_ARCHITECTURES}") + else() + set(LOCAL_CUDA_ARCHITECTURES "" CACHE STRING "Detected local GPUs for compilation") + message("-- No local CUDA-capable GPUs found") + endif() + endif() + + # Add flags to compile for local CUDA architectures + foreach(var ${LOCAL_CUDA_ARCHITECTURES}) + list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_${var},code=sm_${var}) + endforeach() + + cuda_include_directories(${DACE_RUNTIME_DIR}/include) + cuda_compile(DACE_CUDA_OBJECTS ${DACE_CUDA_FILES}) + set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_CUDA_OBJECTS}) +endif() # DACE_ENABLE_CUDA + +# Create Xilinx object files +if(DACE_ENABLE_XILINX) + if((NOT (DACE_XILINX_MODE STREQUAL "hardware")) OR DACE_XILINX_ENABLE_DEBUGGING) + set(DACE_XILINX_HOST_FLAGS "${DACE_XILINX_HOST_FLAGS} -g") + set(DACE_XILINX_SYNTHESIS_FLAGS "${DACE_XILINX_SYNTHESIS_FLAGS} -g") + endif() + + set_source_files_properties(${DACE_XILINX_KERNEL_FILES} ${DACE_XILINX_HOST_FILES} PROPERTIES COMPILE_FLAGS "${DACE_XILINX_HOST_FLAGS}") + set_source_files_properties(${DACE_XILINX_KERNEL_FILES} PROPERTIES COMPILE_FLAGS "-DDACE_XILINX_DEVICE_CODE ${DACE_XILINX_HOST_FLAGS}") + set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_XILINX_KERNEL_FILES} ${DACE_XILINX_HOST_FILES}) + + if(((${SDAccel_MAJOR_VERSION} LESS 2018) AND + (${SDAccel_MINOR_VERSION} LESS 3)) OR ${SDAccel_MAJOR_VERSION} LESS 2017) + add_definitions(-DHLSLIB_LEGACY_SDX=1) + else() + add_definitions(-DHLSLIB_LEGACY_SDX=0) + endif() + + if(DACE_XILINX_MODE STREQUAL "simulation") + # This will cause the OpenCL calls to instead call a simulation code + # running on the host + add_definitions(-DHLSLIB_SIMULATE_OPENCL) + endif() + + set(DACE_XILINX_SYNTHESIS_FLAGS "${DACE_XILINX_SYNTHESIS_FLAGS} -DDACE_SYNTHESIS -DDACE_XILINX -DDACE_XILINX_DEVICE_CODE -DHLSLIB_SYNTHESIS -std=c++11") + + # Add synthesis and build commands + set(DACE_SYNTHESIS_TARGETS) 
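+  # Kernel source files are named "kernel_<kernel name>.cpp"; the loop below
+  # extracts each kernel name, instantiates Xilinx_HLS.tcl.in for it, and
+  # registers a matching xilinx_synthesis_<kernel name> custom target.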
+ foreach(DACE_KERNEL_FILE ${DACE_XILINX_KERNEL_FILES}) + get_filename_component(DACE_KERNEL_NAME ${DACE_KERNEL_FILE} NAME) + string(REGEX REPLACE "kernel_(.+).cpp" "\\1" DACE_KERNEL_NAME "${DACE_KERNEL_NAME}") + string(REPLACE " " ";" DACE_XILINX_SYNTHESIS_FLAGS_INTERNAL ${DACE_XILINX_SYNTHESIS_FLAGS}) + set(DACE_XOCC_KERNEL_FILES ${DACE_XOCC_KERNEL_FILES} ${DACE_KERNEL_FILE}) + set(DACE_XOCC_KERNELS ${DACE_XOCC_KERNELS} --kernel ${DACE_KERNEL_NAME} --xp prop:kernel.${DACE_KERNEL_NAME}.kernel_flags=\"${DACE_XILINX_SYNTHESIS_FLAGS_INTERNAL}\") + + configure_file(${CMAKE_SOURCE_DIR}/Xilinx_HLS.tcl.in Synthesize_${DACE_KERNEL_NAME}.tcl) + add_custom_target(xilinx_synthesis_${DACE_KERNEL_NAME} COMMAND ${SDAccel_VIVADO_HLS} -f Synthesize_${DACE_KERNEL_NAME}.tcl) + set(DACE_SYNTHESIS_TARGETS ${DACE_SYNTHESIS_TARGETS} xilinx_synthesis_${DACE_KERNEL_NAME}) + + endforeach() + + add_custom_target(xilinx_synthesis DEPENDS ${DACE_SYNTHESIS_TARGETS}) + + string(REPLACE " " ";" DACE_XILINX_BUILD_FLAGS_INTERNAL + "${DACE_XILINX_BUILD_FLAGS}") + + set(XOCC_BUILD_FLAGS + -s + -O3 + -I${CMAKE_SOURCE_DIR}/include + -I${CMAKE_SOURCE_DIR}/../external/hlslib/include + -I${CMAKE_SOURCE_DIR}/../runtime/include + -I${CMAKE_BINARY_DIR} + "${DACE_XOCC_KERNELS}" + --platform ${DACE_XILINX_TARGET_PLATFORM} + ${DACE_XILINX_BUILD_FLAGS_INTERNAL} + --kernel_frequency ${DACE_XILINX_TARGET_CLOCK} + --max_memory_ports all) + + if((NOT (DACE_XILINX_MODE STREQUAL "hardware")) OR DACE_XILINX_ENABLE_DEBUGGING) + # TODO: add Chipscope debugging on memory interfaces. Need to pass + # interfaces from codegen to CMake in order to do this. + message(STATUS "Enabled debugging/profiling for Xilinx targets.") + set(XOCC_BUILD_FLAGS ${XOCC_BUILD_FLAGS} + --profile_kernel "data:all:all:all" + --profile_kernel "stall:all:all" + --profile_kernel "exec:all:all") + endif() + + if(SDAccel_MAJOR_VERSION LESS 2018 AND SDAccel_MINOR_VERSION LESS 3) + + add_custom_target( + xilinx_build_${DACE_PROGRAM_NAME}_software_emulation + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -t sw_emu + ${DACE_XOCC_KERNEL_FILES} + -o ${DACE_PROGRAM_NAME}_sw_emu.xclbin) + + add_custom_target( + xilinx_build_${DACE_PROGRAM_NAME}_hardware_emulation + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -t hw_emu + ${DACE_XOCC_KERNEL_FILES} + -o ${DACE_PROGRAM_NAME}_hw_emu.xclbin) + + add_custom_target( + xilinx_build_${DACE_PROGRAM_NAME}_hardware + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -t hw + ${DACE_XOCC_KERNEL_FILES} + -o ${DACE_PROGRAM_NAME}_hw.xclbin) + + else() + + add_custom_target( + xilinx_compile_${DACE_PROGRAM_NAME}_software_emulation + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -c + -t sw_emu + ${DACE_XOCC_KERNEL_FILES} + -o ${DACE_PROGRAM_NAME}_sw_emu.xo) + + add_custom_target( + xilinx_compile_${DACE_PROGRAM_NAME}_hardware_emulation + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -c + -t hw_emu + ${DACE_XOCC_KERNEL_FILES} + -o ${DACE_PROGRAM_NAME}_hw_emu.xo) + + add_custom_target( + xilinx_compile_${DACE_PROGRAM_NAME}_hardware + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -c + -t hw + ${DACE_XOCC_KERNEL_FILES} + -o ${DACE_PROGRAM_NAME}_hw.xo) + + add_custom_target( + xilinx_build_${DACE_PROGRAM_NAME}_software_emulation + COMMAND + XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC} + ${XOCC_BUILD_FLAGS} + -l + -t sw_emu + 
${DACE_PROGRAM_NAME}_sw_emu.xo
+      -o ${DACE_PROGRAM_NAME}_sw_emu.xclbin)
+
+    add_custom_target(
+      xilinx_build_${DACE_PROGRAM_NAME}_hardware_emulation
+      COMMAND
+      XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC}
+      ${XOCC_BUILD_FLAGS}
+      -l
+      -t hw_emu
+      ${DACE_PROGRAM_NAME}_hw_emu.xo
+      -o ${DACE_PROGRAM_NAME}_hw_emu.xclbin)
+
+    add_custom_target(
+      xilinx_build_${DACE_PROGRAM_NAME}_hardware
+      COMMAND
+      XILINX_PATH=${CMAKE_BINARY_DIR} ${SDAccel_XOCC}
+      ${XOCC_BUILD_FLAGS}
+      -l
+      -t hw
+      ${DACE_PROGRAM_NAME}_hw.xo
+      -o ${DACE_PROGRAM_NAME}_hw.xclbin)
+
+  endif()
+
+endif() # DACE_ENABLE_XILINX
+
+# Create DaCe library file
+add_library(${DACE_PROGRAM_NAME} SHARED ${DACE_CPP_FILES} ${DACE_OBJECTS})
+target_link_libraries(${DACE_PROGRAM_NAME} ${DACE_LIBS})
+
+# Create DaCe loader stub
+add_library(dacestub_${DACE_PROGRAM_NAME} SHARED "${CMAKE_SOURCE_DIR}/tools/dacestub.cpp")
+target_link_libraries(dacestub_${DACE_PROGRAM_NAME} ${CMAKE_THREAD_LIBS_INIT} ${OpenMP_CXX_LIBRARIES})
+
+# Windows-specific fixes
+if (MSVC_IDE)
+  # Copy output DLL from the "Debug" and "Release" directories CMake adds
+  # NOTE: The "|| (exit 0)" is added because copy sometimes fails due to the
+  # stub library being already loaded.
+  add_custom_target(CopyDLL ALL
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different
+      $<TARGET_FILE:${DACE_PROGRAM_NAME}> "${CMAKE_BINARY_DIR}/lib${DACE_PROGRAM_NAME}.dll"
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different
+      $<TARGET_FILE:dacestub_${DACE_PROGRAM_NAME}> "${CMAKE_BINARY_DIR}/libdacestub_${DACE_PROGRAM_NAME}.dll" || (exit 0)
+    DEPENDS ${DACE_PROGRAM_NAME}
+    COMMENT "Copying binaries" VERBATIM)
+
+  # Replace /MD with /MT so that CUDA links properly
+  # https://stackoverflow.com/a/14172871/6489142
+  set(CompilerFlags
+      CMAKE_CXX_FLAGS
+      CMAKE_CXX_FLAGS_DEBUG
+      CMAKE_CXX_FLAGS_RELEASE
+      CMAKE_CXX_FLAGS_RELWITHDEBINFO
+      CMAKE_CXX_FLAGS_MINSIZEREL
+      CMAKE_C_FLAGS
+      CMAKE_C_FLAGS_DEBUG
+      CMAKE_C_FLAGS_RELEASE
+      CMAKE_C_FLAGS_RELWITHDEBINFO
+      CMAKE_C_FLAGS_MINSIZEREL
+      )
+  foreach(CompilerFlag ${CompilerFlags})
+    string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}")
+  endforeach()
+endif()
diff --git a/dace/codegen/Xilinx_HLS.tcl.in b/dace/codegen/Xilinx_HLS.tcl.in
new file mode 100644
index 0000000000..7261b513a9
--- /dev/null
+++ b/dace/codegen/Xilinx_HLS.tcl.in
@@ -0,0 +1,14 @@
+open_project ${DACE_KERNEL_NAME}
+open_solution ${DACE_XILINX_PART_NAME}
+add_files -cflags "${DACE_XILINX_SYNTHESIS_FLAGS} -I${DACE_RUNTIME_DIR}/include -I${DACE_HLSLIB_DIR}/include -I${CMAKE_BINARY_DIR}" "${DACE_KERNEL_FILE}"
+set_top ${DACE_KERNEL_NAME}
+set_part ${DACE_XILINX_PART_NAME}
+create_clock -period ${DACE_XILINX_TARGET_CLOCK}MHz -name default
+# SDAccel default options
+config_rtl -register_reset
+config_interface -m_axi_addr64
+config_schedule -relax_ii_for_timing
+config_compile -pipeline_loops 64
+config_compile -name_max_length 256
+csynth_design
+exit
diff --git a/dace/codegen/__init__.py b/dace/codegen/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py
new file mode 100644
index 0000000000..406b0409cf
--- /dev/null
+++ b/dace/codegen/codegen.py
@@ -0,0 +1,67 @@
+import numpy as np
+
+from typing import List
+
+from dace import symbolic
+from dace.codegen.targets import framecode
+from dace.codegen.codeobject import CodeObject
+
+from dace.codegen.instrumentation.perfsettings import PerfSettings, PerfMetaInfoStatic, PerfMetaInfo
+
+# Import all code generation targets
+from dace.codegen.targets import cpu, cuda, immaterial, mpi, xilinx
+
+
+class CodegenError(Exception):
+    pass
+
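+
+# Maps target names (as used in configuration and in the source folder
+# layout under src/) to their code generator classes.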
+STRING_TO_TARGET = {
+    "cpu": cpu.CPUCodeGen,
+    "cuda": cuda.CUDACodeGen,
+    "immaterial": immaterial.ImmaterialCodeGen,
+    "mpi": mpi.MPICodeGen,
+    "xilinx": xilinx.XilinxCodeGen,
+}
+
+_TARGET_REGISTER_ORDER = ['cpu', 'cuda', 'immaterial', 'mpi', 'xilinx']
+
+
+def generate_code(sdfg) -> List[CodeObject]:
+    """ Generates code as a list of code objects for a given SDFG.
+        @param sdfg: The SDFG to use.
+        @return: List of code objects that correspond to files to compile.
+    """
+    # Before compiling, validate SDFG correctness
+    sdfg.validate()
+
+    frame = framecode.DaCeCodeGenerator()
+    # Instantiate all targets (who register themselves with framecodegen)
+    targets = {
+        name: STRING_TO_TARGET[name](frame, sdfg)
+        for name in _TARGET_REGISTER_ORDER
+    }
+
+    # Generate frame code (and the rest of the code)
+    global_code, frame_code, used_targets = frame.generate_code(sdfg, None)
+    target_objects = [
+        CodeObject(
+            sdfg.name,
+            global_code + frame_code,
+            'cpp',
+            cpu.CPUCodeGen,
+            'Frame',
+            meta_info=PerfMetaInfoStatic.info
+            if PerfSettings.perf_enable_vectorization_analysis() else
+            PerfMetaInfo())
+    ]
+    PerfMetaInfoStatic.info = PerfMetaInfo()
+
+    # Create code objects for each target
+    for tgt in used_targets:
+        target_objects.extend(tgt.get_generated_codeobjects())
+
+    return target_objects
+
+
+##################################################################
diff --git a/dace/codegen/codeobject.py b/dace/codegen/codeobject.py
new file mode 100644
index 0000000000..cb25c3884a
--- /dev/null
+++ b/dace/codegen/codeobject.py
@@ -0,0 +1,51 @@
+import ctypes
+import numpy as np
+
+from dace import symbolic, types
+from dace.config import Config
+from dace.frontend import operations
+from dace.properties import Property, make_properties
+from dace.codegen.targets.target import TargetCodeGenerator
+
+from dace.codegen.instrumentation.perfsettings import PerfMetaInfo
+
+
+@make_properties
+class CodeObject(object):
+    name = Property(dtype=str, desc="Filename to use")
+    code = Property(dtype=str, desc="The code attached to this object")
+    perf_meta_info = Property(
+        dtype=PerfMetaInfo, desc="Meta information used to map nodes to LOC")
+    language = Property(
+        dtype=str,
+        desc="Language used for this code (same " +
+        "as its file extension)")  # dtype=types.Language?
+    target = Property(dtype=type, desc="Target to use for compilation")
+    title = Property(dtype=str, desc="Title of code for GUI")
+    extra_compiler_kwargs = Property(
+        dtype=dict,
+        desc="Additional compiler argument "
+        "variables to add to template")
+    linkable = Property(
+        dtype=bool, desc='Should this file participate in '
+        'overall linkage?')
+
+    def __init__(self,
+                 name,
+                 code,
+                 language,
+                 target,
+                 title,
+                 additional_compiler_kwargs={},
+                 linkable=True,
+                 meta_info=PerfMetaInfo()):
+        super(CodeObject, self).__init__()
+
+        self.name = name
+        self.code = code
+        self.language = language
+        self.target = target
+        self.title = title
+        self.extra_compiler_kwargs = additional_compiler_kwargs
+        self.linkable = linkable
+        self.perf_meta_info = meta_info
diff --git a/dace/codegen/compiler.py b/dace/codegen/compiler.py
new file mode 100644
index 0000000000..16dab734b2
--- /dev/null
+++ b/dace/codegen/compiler.py
@@ -0,0 +1,510 @@
+#!/usr/bin/python3
+""" Handles compilation of code objects. Creates the proper folder structure,
+    compiles each target separately, links all targets to one binary, and
+    returns the corresponding CompiledSDFG object.
+"""
+
+from __future__ import print_function
+
+import ctypes
+import os
+import re
+import six
+import shutil
+import subprocess
+import string
+from typing import List
+import numpy as np
+
+import dace
+from dace.frontend import operations
+from dace.frontend.python import ndarray
+from dace import symbolic, types
+from dace.config import Config
+from dace.codegen import codegen
+from dace.codegen.codeobject import CodeObject
+from dace.codegen.targets.cpu import CPUCodeGen
+from dace.codegen.targets.target import make_absolute
+
+from dace.codegen.instrumentation.perfsettings import PerfSettings, PerfMetaInfoStatic
+
+
+# Specialized exception classes
+class DuplicateDLLError(Exception):
+    """ An exception that is raised whenever a library is loaded twice. """
+    pass
+
+
+class CompilerConfigurationError(Exception):
+    """ An exception that is raised whenever CMake encounters a configuration
+        error. """
+    pass
+
+
+class CompilationError(Exception):
+    """ An exception that is raised whenever a compilation error occurs. """
+    pass
+
+
+class ReloadableDLL(object):
+    """ A reloadable shared object (or dynamically linked library), which
+        bypasses Python's dynamic library reloading issues. """
+
+    def __init__(self, library_filename, program_name):
+        """ Creates a new reloadable shared object.
+            @param library_filename: Path to library file.
+            @param program_name: Name of the DaCe program (for use in finding
+                                 the stub library loader).
+        """
+        self._stub_filename = os.path.join(
+            os.path.dirname(os.path.realpath(library_filename)),
+            'libdacestub_%s.%s' %
+            (program_name, Config.get('compiler', 'library_extension')))
+        self._library_filename = library_filename
+        self._stub = None
+        self._lib = None
+
+    def get_symbol(self, name, restype=ctypes.c_int):
+        """ Returns a symbol (e.g., function name) in the loaded library. """
+
+        if self._lib is None or self._lib.value is None:
+            raise ReferenceError(
+                'ReloadableDLL can only be used with a ' +
+                '"with" statement or with load() and unload()')
+
+        func = self._stub.get_symbol(self._lib, ctypes.c_char_p(name.encode()))
+        if func is None:
+            raise KeyError('Function %s not found in library %s' %
+                           (name, os.path.basename(self._library_filename)))
+
+        return ctypes.CFUNCTYPE(restype)(func)
+
+    def load(self):
+        """ Loads the internal library using the stub.
""" + + # If internal library is already loaded, skip + if self._lib is not None and self._lib.value is not None: + return + self._stub = ctypes.CDLL(self._stub_filename) + + # Set return types of stub functions + self._stub.load_library.restype = ctypes.c_void_p + self._stub.get_symbol.restype = ctypes.c_void_p + + # Convert library filename to string according to OS + if os.name == 'nt': + # As UTF-16 + lib_cfilename = ctypes.c_wchar_p(self._library_filename) + else: + # As UTF-8 + lib_cfilename = ctypes.c_char_p( + self._library_filename.encode('utf-8')) + + # Check if library is already loaded + is_loaded = self._stub.is_library_loaded(lib_cfilename) + if is_loaded == 1: + raise DuplicateDLLError( + 'Library %s is already loaded somewhere else, ' % + os.path.basename(self._library_filename) + + 'either unload it or use a different name ' + + 'for the SDFG/program.') + + # Actually load the library + self._lib = ctypes.c_void_p(self._stub.load_library(lib_cfilename)) + + if self._lib.value is None: + raise RuntimeError('Could not load library %s' % os.path.basename( + self._library_filename)) + + def unload(self): + """ Unloads the internal library using the stub. """ + + if self._stub is None: + return + + self._stub.unload_library(self._lib) + self._lib = None + del self._stub + self._stub = None + + def __enter__(self, *args, **kwargs): + self.load() + return self + + def __exit__(self, *args, **kwargs): + self.unload() + + +class CompiledSDFG(object): + """ A compiled SDFG object that can be called through Python. """ + + def __init__(self, sdfg, lib: ReloadableDLL): + self._sdfg = sdfg + self._lib = lib + self._initialized = False + self._lastargs = () + lib.load() # Explicitly load the library + self._init = lib.get_symbol('__dace_init') + self._exit = lib.get_symbol('__dace_exit') + self._cfunc = lib.get_symbol('__program_{}'.format(sdfg.name)) + + @property + def sdfg(self): + return self._sdfg + + def __del__(self): + if self._initialized == True: + self.finalize(*self._lastargs) + self._initialized = False + self._lib.unload() + + def _construct_args(self, *args, **kwargs): + """ Main function that controls argument construction for calling + the C prototype of the SDFG. + + Organizes arguments first by `sdfg.arglist`, then data descriptors + by alphabetical order, then symbols by alphabetical order. + """ + + if len(kwargs) > 0 and len(args) > 0: + raise AttributeError( + 'Compiled SDFGs can only be called with either arguments ' + + '(e.g. 
"program(a,b,c)") or keyword arguments ' + + '("program(A=a,B=b)"), but not both') + + # Argument construction + sig = [] + if len(kwargs) > 0: + # Construct mapping from arguments to signature + sig = self._sdfg.signature_arglist(with_types=False) + arglist = [] + for a in sig: + try: + arglist.append(kwargs[a]) + except KeyError: + raise KeyError("Missing kernel argument \"{}\"".format(a)) + elif len(args) > 0: + arglist = list(args) + else: + arglist = [] + + sdfg = self._sdfg + + # As in compilation, add symbols used in array sizes to parameters + symparams = {} + for symname in sdfg.undefined_symbols(False): + # Ignore arguments (as they may not be symbols but constants, + # see below) + if symname in sdfg.arg_types: continue + try: + symval = symbolic.symbol(symname) + symparams[symname] = symval.get() + except UnboundLocalError: + try: + symparams[symname] = kwargs[symname] + except KeyError: + raise UnboundLocalError('Unassigned symbol %s' % symname) + + arglist.extend( + [symparams[k] for k in sorted(symparams.keys()) if k not in sig]) + + # Obtain SDFG constants + constants = sdfg.constants + + # Remove symbolic constants from arguments + callparams = tuple( + arg for arg in arglist if not symbolic.issymbolic(arg) or ( + hasattr(arg, 'name') and arg.name not in constants)) + + # Replace symbols with their values + callparams = tuple( + symbolic.eval(arg) if symbolic.issymbolic(arg, constants) else arg + for arg in callparams) + + # Replace arrays with their pointers + newargs = tuple( + ctypes.c_void_p(arg.__array_interface__['data'][0]) if (isinstance( + arg, ndarray.ndarray) or isinstance(arg, np.ndarray)) else arg + for arg in callparams) + + newargs = tuple(types._FFI_CTYPES[type(arg)](arg) + if type(arg) in types._FFI_CTYPES else arg + for arg in newargs) + + self._lastargs = newargs + return self._lastargs + + def initialize(self, *argtuple): + if self._init is not None: + res = self._init(*argtuple) + if res != 0: + raise RuntimeError('DaCe application failed to initialize') + + self._initialized = True + + def finalize(self, *argtuple): + if self._exit is not None: + self._exit(*argtuple) + + def __call__(self, *args, **kwargs): + argtuple = self._construct_args(*args, **kwargs) + + # Call initializer function if necessary, then SDFG + if self._initialized == False: + self.initialize(*argtuple) + + # PROFILING + if Config.get_bool('profiling'): + operations.timethis(self._sdfg.name, 'DaCe', 0, self._cfunc, + *argtuple) + else: + return self._cfunc(*argtuple) + + +def unique_flags(flags): + pattern = '[^ ]+[`\'"][^"\'`]+["\'`]|[^ ]+' + if not isinstance(flags, str): + flags = " ".join(flags) + return set(re.findall(pattern, flags)) + + +def generate_program_folder(code_objects: List[CodeObject], out_path): + """ Writes all files required to configure and compile the DaCe program + into the specified folder. + + @param code_objects: List of generated code objects. + @param out_path: The folder in which the build files should be written. + @return: Path to the program folder. 
+ """ + + src_path = os.path.join(out_path, "src") + + try: + os.makedirs(src_path) + except FileExistsError: + pass + + filelist = [] + # Write each code object to a file + for code_object in code_objects: + + name = code_object.name + extension = code_object.language + target_name = code_object.target.target_name + + # Create target folder + target_folder = os.path.join(src_path, target_name) + try: + os.makedirs(target_folder) + except FileExistsError: + pass + + # Write code to file + basename = "{}.{}".format(name, extension) + code_path = os.path.join(target_folder, basename) + with open(code_path, "w") as code_file: + clean_code = re.sub(r'[ \t]*////__DACE:[^\n]*', '', + code_object.code) + + if PerfSettings.perf_enable_vectorization_analysis(): + # Generate line number information from the code + # TODO: Make per code stream + code_object.perf_meta_info.resolve(clean_code) + code_file.write(clean_code) + + filelist.append("{},{}".format(target_name, basename)) + + # Write list of files + with open(os.path.join(out_path, "dace_files.csv"), "w") as filelist_file: + filelist_file.write("\n".join(filelist)) + + # Copy snapshot of configuration script + Config.save(os.path.join(out_path, "dace.conf")) + + return out_path + + +def configure_and_compile(program_folder, program_name=None): + """ Configures and compiles a DaCe program in the specified folder into a + shared library file. + + @param program_folder: Folder containing all files necessary to build, + equivalent to what was passed to + `generate_program_folder`. + @return: Path to the compiled shared library file. + """ + + if program_name is None: + program_name = os.path.basename(program_folder) + program_folder = os.path.abspath(program_folder) + src_folder = os.path.join(program_folder, "src") + + # Prepare build folder + build_folder = os.path.join(program_folder, "build") + try: + os.makedirs(build_folder) + except FileExistsError: + pass + + # Read list of DaCe files to compile. + # We do this instead of iterating over source files in the directory to + # avoid globbing files from previous compilations, such that we don't need + # to wipe the directory for every compilation. 
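+    # Each line of dace_files.csv has the form "<target name>,<file basename>"
+    # (e.g., "cpu,<program name>.cpp"), as written by generate_program_folder.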
+ file_list = [ + line.strip().split(",") + for line in open(os.path.join(program_folder, "dace_files.csv"), "r") + ] + + # Get absolute paths and targets for all source files + files = [] + targets = {} # {target name: target class} + for target_name, file_name in file_list: + path = os.path.join(src_folder, target_name, file_name) + files.append(path) + targets[target_name] = codegen.STRING_TO_TARGET[target_name] + + # Start forming CMake command + dace_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + cmake_command = [ + "cmake", + "-A x64" if os.name == 'nt' else "", # Windows-specific flag + '"' + os.path.join(dace_path, "codegen") + '"', + "-DDACE_FILES=\"{}\"".format(";".join(files)), + "-DDACE_PROGRAM_NAME={}".format(program_name), + ] + + # Replace backslashes with forward slashes + cmake_command = [cmd.replace('\\', '/') for cmd in cmake_command] + + # Generate CMake options for each compiler + libraries = set() + for target_name, target in targets.items(): + cmake_command += target.cmake_options() + try: + libraries |= unique_flags( + Config.get("compiler", target_name, "libs")) + except KeyError: + pass + + # TODO: it should be possible to use the default arguments/compilers + # found by CMake + cmake_command += [ + "-DDACE_LIBS=\"{}\"".format(" ".join(libraries)), + "-DCMAKE_LINKER=\"{}\"".format( + make_absolute(Config.get('compiler', 'linker', 'executable'))), + "-DCMAKE_SHARED_LINKER_FLAGS=\"{}\"".format( + Config.get('compiler', 'linker', 'args') + + Config.get('compiler', 'linker', 'additional_args')), + ] + + ############################################## + # Configure + try: + _run_liveoutput(" ".join(cmake_command), shell=True, cwd=build_folder) + except subprocess.CalledProcessError as ex: + # Clean CMake directory and try once more + if Config.get_bool('debugprint'): + print('Cleaning CMake build folder and retrying...') + shutil.rmtree(build_folder) + os.makedirs(build_folder) + try: + _run_liveoutput( + " ".join(cmake_command), shell=True, cwd=build_folder) + except subprocess.CalledProcessError as ex: + # If still unsuccessful, print results + if Config.get_bool('debugprint'): + raise CompilerConfigurationError('Configuration failure') + else: + raise CompilerConfigurationError('Configuration failure:\n' + + ex.output) + + # Compile and link + try: + _run_liveoutput( + "cmake --build . 
--config %s" % (Config.get( + 'compiler', 'build_type')), + shell=True, + cwd=build_folder) + except subprocess.CalledProcessError as ex: + # If unsuccessful, print results + if Config.get_bool('debugprint'): + raise CompilationError('Compiler failure') + else: + raise CompilationError('Compiler failure:\n' + ex.output) + + shared_library_path = os.path.join( + build_folder, "lib{}.{}".format( + program_name, Config.get('compiler', 'library_extension'))) + + return shared_library_path + + +def get_program_handle(library_path, sdfg): + lib = ReloadableDLL(library_path, sdfg.name) + # Load and return the compiled function + return CompiledSDFG(sdfg, lib) + + +def load_from_file(sdfg, binary_filename): + if not os.path.isfile(binary_filename): + raise FileNotFoundError('File not found: ' + binary_filename) + + # Load the generated library + lib = ReloadableDLL(binary_filename, sdfg.name) + + # Load and return the compiled function + return CompiledSDFG(sdfg, lib) + + +def get_binary_name(object_name, + object_hash=None, + lib_extension=Config.get('compiler', 'library_extension')): + name = None + if object_hash is None: + name = os.path.join('.dacecache', object_name, "build", + 'lib%s.%s' % (object_name, lib_extension)) + else: + name = os.path.join( + '.dacecache', object_name, "build", + 'lib%s_%s.%s' % (object_name, object_hash, lib_extension)) + return name + + +def _run_liveoutput(command, **kwargs): + process = subprocess.Popen( + command, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, **kwargs) + output = six.StringIO() + while True: + line = process.stdout.readline().rstrip() + if not line: + break + output.write(line.decode('utf-8') + '\n') + if Config.get_bool('debugprint'): + print(line.decode('utf-8'), flush=True) + stdout, stderr = process.communicate() + if Config.get_bool('debugprint'): + print(stdout.decode('utf-8'), flush=True) + if stderr is not None: + print(stderr.decode('utf-8'), flush=True) + output.write(stdout.decode('utf-8')) + if stderr is not None: + output.write(stderr.decode('utf-8')) + + # An error occurred, raise exception + if process.returncode != 0: + raise subprocess.CalledProcessError(process.returncode, command, + output.getvalue()) + + +# Allow configuring and compiling a prepared build folder from the commandline. +# This is useful for remote execution. +if __name__ == "__main__": + import argparse + + argparser = argparse.ArgumentParser() + argparser.add_argument("path", type=str) + argparser.add_argument("outname", type=str) + args = vars(argparser.parse_args()) + + Config.load(os.path.join(args["path"], "dace.conf")) + + configure_and_compile(args["path"], args["outname"]) diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py new file mode 100644 index 0000000000..904057a62b --- /dev/null +++ b/dace/codegen/cppunparse.py @@ -0,0 +1,1093 @@ +# This module is derived from astunparse: https://github.com/simonpercivall/astunparse +########################################################################## +### astunparse LICENSES +# LICENSE +# ================== +# +# Copyright (c) 2014, Simon Percivall +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +# +# * Neither the name of AST Unparser nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. 
This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +########################################################################## +### END OF astunparse LICENSES + +from __future__ import print_function, unicode_literals +import inspect +import six +import sys +import ast +import os +import tokenize +from six import StringIO + +# Large float and imaginary literals get turned into infinities in the AST. +# We unparse those infinities to INFSTR. +INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) + +_py2c_nameconst = {True: "true", False: "false", None: "nullptr"} + +_py2c_reserved = {"True": "true", "False": "false", "None": "nullptr"} + + +def interleave(inter, f, seq): + """Call f on each item in seq, calling inter() in between. + """ + seq = iter(seq) + try: + f(next(seq)) + except StopIteration: + pass + else: + for x in seq: + inter() + f(x) + + +class LocalScheme(object): + def is_defined(self, local_name, current_depth): + raise NotImplementedError('Abstract class') + + def define(self, local_name, lineno, depth): + raise NotImplementedError('Abstract class') + + def clear_scope(self, from_indentation): + raise NotImplementedError('Abstract class') + + +class CPPLocals(LocalScheme): + def __init__(self): + # Maps local name to a 2-tuple of line number and scope (measured in indentation) + self.locals = {} + + def is_defined(self, local_name, current_depth): + return local_name in self.locals + + def define(self, local_name, lineno, depth): + self.locals[local_name] = (lineno, depth) + + def clear_scope(self, from_indentation): + """Clears all locals defined in indentation 'from_indentation' and deeper""" + toremove = set() + for local_name, (lineno, depth) in self.locals.items(): + if depth >= from_indentation: + toremove.add(local_name) + + for var in toremove: + del self.locals[var] + + +# Python scheme: All global variables can be read, but not written to (unless defined as "global") +class PythonLocals(LocalScheme): + def __init__(self): + # Maps local name to a 2-tuple of line number and scope (measured in indentation) + self.locals = {} + + def is_defined(self, local_name, current_depth): + return local_name in self.locals and self.locals[local_name][1] == current_depth + + def define(self, local_name, lineno, depth): + self.locals[local_name] = (lineno, depth) + + def clear_scope(self, from_indentation): + """Clears all locals defined in indentation 'from_indentation' and deeper""" + toremove = set() + for local_name, (lineno, depth) in self.locals.items(): + if depth >= from_indentation: + toremove.add(local_name) + for var in toremove: + del self.locals[var] + + +class CPPUnparser: + """Methods in this class recursively traverse an AST and + output C++ source code for the abstract syntax; original formatting + is disregarded. 
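+
+        For example, the Python statement `a = b + 1` (with `a` not yet
+        defined in the current scope) is emitted as `auto a = (b + 1);`.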
""" + + def __init__(self, + tree, + depth, + locals, + file=sys.stdout, + indent_output=True, + expr_semicolon=True, + indent_offset=0): + + self.f = file + self.future_imports = [] + self._indent = depth + self.indent_output = indent_output + self.indent_offset = indent_offset + self.expr_semicolon = expr_semicolon + if not isinstance(locals, LocalScheme): + raise TypeError('Locals must be a LocalScheme object') + self.locals = locals + self.firstfill = True + + self.dispatch(tree) + print("", file=self.f) + self.f.flush() + + def fill(self, text=""): + """Indent a piece of text, according to the current indentation level""" + if self.firstfill: + if self.indent_output: + self.f.write(" " * (self._indent + self.indent_offset) + + text) + else: + self.f.write(text) + self.firstfill = False + else: + if self.indent_output: + self.f.write("\n" + " " * + (self._indent + self.indent_offset) + text) + else: + self.f.write("\n" + text) + + def write(self, text): + """Append a piece of text to the current line.""" + self.f.write(six.text_type(text)) + + def enter(self): + """Print '{', and increase the indentation.""" + self.write(" {") + self._indent += 1 + + def leave(self): + """Decrease the indentation and print '}'.""" + self._indent -= 1 + self.fill() + self.write("}") + # Clear locals defined inside scope + self.locals.clear_scope(self._indent + 1) + + def dispatch(self, tree): + """Dispatcher function, dispatching tree type T to method _T.""" + try: + tree = iter(tree) + for t in tree: + self.dispatch(t) + except TypeError: + meth = getattr(self, "_" + tree.__class__.__name__) + meth(tree) + + ############### Unparsing methods ###################### + # There should be one method per concrete grammar type # + # Constructors should be grouped by sum type. Ideally, # + # this would follow the order in the grammar, but # + # currently doesn't. # + ######################################################## + + def _Module(self, tree): + for stmt in tree.body: + self.dispatch(stmt) + + def _Interactive(self, tree): + for stmt in tree.body: + self.dispatch(stmt) + + def _Expression(self, tree): + self.dispatch(tree.body) + + # stmt + def _Expr(self, tree): + self.fill() + self.dispatch(tree.value) + if self.expr_semicolon: + self.write(';') + + def _Import(self, t): + raise SyntaxError('Invalid C++') + + def _ImportFrom(self, t): + raise SyntaxError('Invalid C++') + + def dispatch_lhs_tuple(self, targets): + # Decide whether to use the C++17 syntax for undefined variables or std::tie for defined variables + if all( + self.locals.is_defined(target.id, self._indent) + for target in targets): + defined = True + elif any( + self.locals.is_defined(target.id, self._indent) + for target in targets): + raise SyntaxError( + 'Invalid C++ (some variables in tuple were already defined)') + else: + defined = False + + if not defined: # C++17 syntax: auto [a,b,...,z] = ... + self.write("auto [") + else: # C++14 syntax: std::tie(a,b,...,z) = ... 
self.write("std::tie(")
+
+        first = True
+        for target in targets:
+            if not first:
+                self.write(', ')
+            self.locals.define(target.id, target.lineno, self._indent)
+            self.dispatch(target)
+            first = False
+
+        if not defined:
+            self.write("]")
+        else:
+            self.write(")")
+
+    def _Assign(self, t):
+        self.fill()
+
+        # Handle the case of a tuple output
+        if len(t.targets) > 1:
+            self.dispatch_lhs_tuple(t.targets)
+        else:
+            target = t.targets[0]
+            if isinstance(target, ast.Tuple):
+                if len(target.elts) > 1:
+                    self.dispatch_lhs_tuple(target.elts)
+                target = target.elts[0]
+
+            if not isinstance(target,
+                              ast.Subscript) and not self.locals.is_defined(
+                                  target.id, self._indent):
+                self.locals.define(target.id, t.lineno, self._indent)
+                self.write('auto ')
+            self.dispatch(target)
+
+        self.write(" = ")
+        self.dispatch(t.value)
+        self.write(';')
+
+    def _AugAssign(self, t):
+        self.fill()
+        self.dispatch(t.target)
+        # Operations that require a function call
+        if t.op.__class__.__name__ in self.funcops:
+            separator, func = self.funcops[t.op.__class__.__name__]
+            self.write(" = " + func + "(")
+            self.dispatch(t.target)
+            self.write(separator + " ")
+            self.dispatch(t.value)
+            self.write(")")
+        else:
+            self.write(" " + self.binop[t.op.__class__.__name__] + "= ")
+            self.dispatch(t.value)
+        self.write(';')
+
+    def _AnnAssign(self, t):
+        self.fill()
+
+        if isinstance(t.target, ast.Tuple):
+            if len(t.target.elts) > 1:
+                self.dispatch_lhs_tuple(t.target.elts)
+            else:
+                target = t.target.elts[0]
+        else:
+            target = t.target
+
+        # Assignment of the form x: int = 0 is converted to int x = (int)0;
+        if not self.locals.is_defined(target.id, self._indent):
+            self.locals.define(target.id, t.lineno, self._indent)
+            self.dispatch(t.annotation)
+            self.write(' ')
+        if not t.simple:
+            self.write("(")
+        self.dispatch(t.target)
+        if not t.simple:
+            self.write(")")
+        if t.value:
+            self.write(" = (")
+            self.dispatch(t.annotation)
+            self.write(")")
+            self.dispatch(t.value)
+        self.write(';')
+
+    def _Return(self, t):
+        self.fill("return")
+        if t.value:
+            self.write(" ")
+            self.dispatch(t.value)
+        self.write(';')
+
+    def _Pass(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _Break(self, t):
+        self.fill("break;")
+
+    def _Continue(self, t):
+        self.fill("continue;")
+
+    def _Delete(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _Assert(self, t):
+        self.fill("assert(")
+        self.dispatch(t.test)
+        if t.msg:
+            self.write(", ")
+            self.dispatch(t.msg)
+        self.write(");")
+
+    def _Exec(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _Print(self, t):
+        do_comma = False
+        if t.dest:
+            self.fill("fprintf(")
+            self.dispatch(t.dest)
+            do_comma = True
+        else:
+            self.fill("printf(")
+
+        for e in t.values:
+            if do_comma: self.write(", ")
+            else: do_comma = True
+            self.dispatch(e)
+        if not t.nl:
+            self.write(",")
+
+        self.write(');')
+
+    def _Global(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _Nonlocal(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _Yield(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _YieldFrom(self, t):
+        raise SyntaxError('Invalid C++')
+
+    def _Raise(self, t):
+        self.fill("throw")
+        if six.PY3:
+            if not t.exc:
+                assert not t.cause
+                return
+            self.write(" ")
+            self.dispatch(t.exc)
+            if t.cause:
+                raise SyntaxError('Invalid C++')
+        else:
+            self.write(" ")
+            if t.type:
+                self.dispatch(t.type)
+            if t.inst:
+                self.write(", ")
+                self.dispatch(t.inst)
+            if t.tback:
+                self.write(", ")
+                self.dispatch(t.tback)
+        self.write(';')
+
+    def _Try(self, t):
+        self.fill("try")
+        self.enter()
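+        # (each exception handler is emitted by _ExceptHandler below as a
+        # C++ catch clause)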
self.dispatch(t.body) + self.leave() + for ex in t.handlers: + self.dispatch(ex) + if t.orelse: + raise SyntaxError('Invalid C++') + if t.finalbody: + self.fill("finally") + self.enter() + self.dispatch(t.finalbody) + self.leave() + + def _TryExcept(self, t): + self.fill("try") + self.enter() + self.dispatch(t.body) + self.leave() + + for ex in t.handlers: + self.dispatch(ex) + if t.orelse: + raise SyntaxError('Invalid C++') + + def _TryFinally(self, t): + if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept): + # try-except-finally + self.dispatch(t.body) + else: + self.fill("try") + self.enter() + self.dispatch(t.body) + self.leave() + + self.fill("finally") + self.enter() + self.dispatch(t.finalbody) + self.leave() + + def _ExceptHandler(self, t): + self.fill("catch (") + if t.type: + self.dispatch(t.type) + if t.name: + if six.PY3: + self.write(t.name) + else: + self.dispatch(t.name) + self.write(')') + self.enter() + self.dispatch(t.body) + self.leave() + + def _ClassDef(self, t): + raise NotImplementedError('Classes are unsupported') + + # Original class definition from astunparse + #self.write("\n") + #for deco in t.decorator_list: + # self.fill("@") + # self.dispatch(deco) + #self.fill("class "+t.name) + #if six.PY3: + # self.write("(") + # comma = False + # for e in t.bases: + # if comma: self.write(", ") + # else: comma = True + # self.dispatch(e) + # for e in t.keywords: + # if comma: self.write(", ") + # else: comma = True + # self.dispatch(e) + # if sys.version_info[:2] < (3, 5): + # if t.starargs: + # if comma: self.write(", ") + # else: comma = True + # self.write("*") + # self.dispatch(t.starargs) + # if t.kwargs: + # if comma: self.write(", ") + # else: comma = True + # self.write("**") + # self.dispatch(t.kwargs) + # self.write(")") + #elif t.bases: + # self.write("(") + # for a in t.bases: + # self.dispatch(a) + # self.write(", ") + # self.write(")") + #self.enter() + #self.dispatch(t.body) + #self.leave() + + def _generic_FunctionDef(self, t, is_async=False): + self.write("\n") + for deco in t.decorator_list: + self.fill("// Decorator: ") + self.dispatch(deco) + if is_async: + self.write('/* async */ ') + + if getattr(t, "returns", False): + if isinstance(t.returns, ast.NameConstant): + if t.returns.value is None: + self.write('void') + else: + self.dispatch(t.returns) + else: + self.dispatch(t.returns) + + self.fill(" " + t.name + "(") + else: + self.fill("auto " + t.name + "(") + + self.dispatch(t.args) + + self.write(")") + self.enter() + self.dispatch(t.body) + self.leave() + + def _FunctionDef(self, t): + self._generic_FunctionDef(t) + + def _AsyncFunctionDef(self, t): + self._generic_FunctionDef(t, is_async=True) + + def _generic_For(self, t, is_async=False): + if is_async: + self.fill("/* async */ for (") + else: + self.fill("for (") + if isinstance(t.target, ast.Tuple): + self.write("auto ") + if len(t.target.elts) == 1: + (elt, ) = t.target.elts + self.locals.define(elt.id, t.lineno, self._indent + 1) + self.dispatch(elt) + else: + self.write("[") + interleave(lambda: self.write(", "), self.dispatch, + t.target.elts) + for elt in t.target.elts: + self.locals.define(elt.id, t.lineno, self._indent + 1) + self.write("]") + + else: + if not self.locals.is_defined(t.target.id, self._indent): + self.locals.define(t.target.id, t.lineno, self._indent + 1) + self.write('auto ') + self.dispatch(t.target) + + self.write(" : ") + self.dispatch(t.iter) + self.write(")") + self.enter() + self.dispatch(t.body) + self.leave() + if t.orelse: + raise SyntaxError('Invalid 
C++') + + def _For(self, t): + self._generic_For(t) + + def _AsyncFor(self, t): + self._generic_For(t, is_async=True) + + def _If(self, t): + self.fill("if (") + self.dispatch(t.test) + self.write(')') + self.enter() + self.dispatch(t.body) + self.leave() + # collapse nested ifs into equivalent elifs. + while (t.orelse and len(t.orelse) == 1 + and isinstance(t.orelse[0], ast.If)): + t = t.orelse[0] + self.fill("else if (") + self.dispatch(t.test) + self.write(')') + self.enter() + self.dispatch(t.body) + self.leave() + # final else + if t.orelse: + self.fill("else") + self.enter() + self.dispatch(t.orelse) + self.leave() + + def _While(self, t): + self.fill("while (") + self.dispatch(t.test) + self.write(')') + self.enter() + self.dispatch(t.body) + self.leave() + if t.orelse: + raise SyntaxError('Invalid C++') + + def _generic_With(self, t, is_async=False): + raise SyntaxError('Invalid C++') + + def _With(self, t): + self._generic_With(t) + + def _AsyncWith(self, t): + self._generic_With(t, is_async=True) + + # expr + def _Bytes(self, t): + self.write(repr(t.s)) + + def _Str(self, tree): + result = '' + if six.PY3: + result = repr(tree.s) + else: + # if from __future__ import unicode_literals is in effect, + # then we want to output string literals using a 'b' prefix + # and unicode literals with no prefix. + if "unicode_literals" not in self.future_imports: + result = repr(tree.s) + elif isinstance(tree.s, str): + result = "b" + repr(tree.s) + elif isinstance(tree.s, unicode): + result = repr(tree.s).lstrip("u") + else: + assert False, "shouldn't get here" + + self.write(result.replace('\'', '\"')) + + format_conversions = {97: 'a', 114: 'r', 115: 's'} + + def _FormattedValue(self, t): + # FormattedValue(expr value, int? conversion, expr? format_spec) + self.write("{") + self.dispatch(t.value) + if t.conversion is not None and t.conversion != -1: + self.write("!") + self.write(self.format_conversions[t.conversion]) + #raise NotImplementedError(ast.dump(t, True, True)) + if t.format_spec is not None: + self.write(":") + if isinstance(t.format_spec, ast.Str): + self.write(t.format_spec.s) + else: + self.dispatch(t.format_spec) + self.write("}") + + def _JoinedStr(self, t): + # JoinedStr(expr* values) + self.write("f'''") + for value in t.values: + if isinstance(value, ast.Str): + self.write(value.s) + else: + self.dispatch(value) + self.write("'''") + + def _Name(self, t): + if t.id in _py2c_reserved: + self.write(_py2c_reserved[t.id]) + else: + self.write(t.id) + + def _NameConstant(self, t): + self.write(_py2c_nameconst[t.value]) + + def _Repr(self, t): + raise SyntaxError('Invalid C++') + + def _Num(self, t): + repr_n = repr(t.n) + if six.PY3: + if repr_n.endswith("j"): + # FIXME: Complex is not a native type in C++, this type-hack should deduce the target type + self.write( + "dace::complexJ()*%s" % repr_n.replace("inf", INFSTR)[:-1]) + else: + self.write(repr_n.replace("inf", INFSTR)) + else: + # Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2. + if repr_n.startswith("-"): + self.write("(") + if "inf" in repr_n and repr_n.endswith("*j"): + repr_n = repr_n.replace("*j", "j") + + if repr_n.endswith("j"): + # FIXME: Complex is not a native type in C++, this type-hack should deduce the target type + self.write( + "dace::complexJ()*%s" % repr_n.replace("inf", INFSTR)[:-1]) + else: + # Substitute overflowing decimal literal for AST infinities. 
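+                # (e.g., the literal 1e1000 becomes inf in the AST and is
+                # re-emitted as 1e309, i.e., INFSTR)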
+ self.write(repr_n.replace("inf", INFSTR)) + + if repr_n.startswith("-"): + self.write(")") + + def _List(self, t): + raise SyntaxError('Invalid C++') + #self.write("[") + #interleave(lambda: self.write(", "), self.dispatch, t.elts) + #self.write("]") + + def _ListComp(self, t): + raise SyntaxError('Invalid C++') + #self.write("[") + #self.dispatch(t.elt) + #for gen in t.generators: + # self.dispatch(gen) + #self.write("]") + + def _GeneratorExp(self, t): + raise SyntaxError('Invalid C++') + #self.write("(") + #self.dispatch(t.elt) + #for gen in t.generators: + # self.dispatch(gen) + #self.write(")") + + def _SetComp(self, t): + raise SyntaxError('Invalid C++') + #self.write("{") + #self.dispatch(t.elt) + #for gen in t.generators: + # self.dispatch(gen) + #self.write("}") + + def _DictComp(self, t): + raise SyntaxError('Invalid C++') + #self.write("{") + #self.dispatch(t.key) + #self.write(": ") + #self.dispatch(t.value) + #for gen in t.generators: + # self.dispatch(gen) + #self.write("}") + + def _comprehension(self, t): + raise SyntaxError('Invalid C++') + #if getattr(t, 'is_async', False): + # self.write(" async") + #self.write(" for ") + #self.dispatch(t.target) + #self.write(" in ") + #self.dispatch(t.iter) + #for if_clause in t.ifs: + # self.write(" if ") + # self.dispatch(if_clause) + + def _IfExp(self, t): + self.write("(") + self.dispatch(t.test) + self.write(" ? ") + self.dispatch(t.body) + self.write(" : ") + self.dispatch(t.orelse) + self.write(")") + + def _Set(self, t): + raise SyntaxError('Invalid C++') + #assert(t.elts) # should be at least one element + #self.write("{") + #interleave(lambda: self.write(", "), self.dispatch, t.elts) + #self.write("}") + + def _Dict(self, t): + raise SyntaxError('Invalid C++') + #self.write("{") + #def write_pair(pair): + # (k, v) = pair + # self.dispatch(k) + # self.write(": ") + # self.dispatch(v) + #interleave(lambda: self.write(", "), write_pair, zip(t.keys, t.values)) + #self.write("}") + + def _Tuple(self, t): + self.write("std::make_tuple(") + if len(t.elts) == 1: + (elt, ) = t.elts + self.dispatch(elt) + self.write(",") + else: + interleave(lambda: self.write(", "), self.dispatch, t.elts) + self.write(")") + + unop = {"Invert": "~", "Not": "!", "UAdd": "+", "USub": "-"} + + def _UnaryOp(self, t): + self.write("(") + self.write(self.unop[t.op.__class__.__name__]) + self.write(" ") + if six.PY2 and isinstance(t.op, ast.USub) and isinstance( + t.operand, ast.Num): + # If we're applying unary minus to a number, parenthesize the number. + # This is necessary: -2147483648 is different from -(2147483648) on + # a 32-bit machine (the first is an int, the second a long), and + # -7j is different from -(7j). (The first has real part 0.0, the second + # has real part -0.0.) 
+ self.write("(") + self.dispatch(t.operand) + self.write(")") + else: + self.dispatch(t.operand) + self.write(")") + + binop = { + "Add": "+", + "Sub": "-", + "Mult": "*", + "Div": "/", + "Mod": "%", + "LShift": "<<", + "RShift": ">>", + "BitOr": "|", + "BitXor": "^", + "BitAnd": "&" + } + funcops = { + "FloorDiv": (" /", "dace::math::ifloor"), + "MatMult": (",", "dace::gemm") + } + + def _BinOp(self, t): + # Operations that require a function call + if t.op.__class__.__name__ in self.funcops: + separator, func = self.funcops[t.op.__class__.__name__] + self.write(func + "(") + self.dispatch(t.left) + self.write(separator + " ") + self.dispatch(t.right) + self.write(")") + # Special case for integer power + elif t.op.__class__.__name__ == 'Pow': + if (isinstance(t.right, ast.Num) and int(t.right.n) == t.right.n + and t.right.n >= 0): + self.write("(") + if t.right.n == 0: + self.write("1") + else: + self.dispatch(t.left) + for i in range(int(t.right.n) - 1): + self.write(" * ") + self.dispatch(t.left) + self.write(")") + else: + self.write("dace::math::pow(") + self.dispatch(t.left) + self.write(", ") + self.dispatch(t.right) + self.write(")") + else: + self.write("(") + self.dispatch(t.left) + self.write(" " + self.binop[t.op.__class__.__name__] + " ") + self.dispatch(t.right) + self.write(")") + + cmpops = { + "Eq": "==", + "NotEq": "!=", + "Lt": "<", + "LtE": "<=", + "Gt": ">", + "GtE": ">=", + "Is": "==", + "IsNot": "!=", + #"In":"in", "NotIn":"not in" + } + + def _Compare(self, t): + self.write("(") + self.dispatch(t.left) + for o, e in zip(t.ops, t.comparators): + if o.__class__.__name__ not in self.cmpops: + raise SyntaxError('Invalid C++') + + self.write(" " + self.cmpops[o.__class__.__name__] + " ") + self.dispatch(e) + self.write(")") + + boolops = {ast.And: '&&', ast.Or: '||'} + + def _BoolOp(self, t): + self.write("(") + s = " %s " % self.boolops[t.op.__class__] + interleave(lambda: self.write(s), self.dispatch, t.values) + self.write(")") + + def _Attribute(self, t): + self.dispatch(t.value) + # Special case: 3.__abs__() is a syntax error, so if t.value + # is an integer literal then we need to either parenthesize + # it or add an extra space to get 3 .__abs__(). 
+ if isinstance(t.value, ast.Num) and isinstance(t.value.n, int): + self.write(" ") + self.write(".") + self.write(t.attr) + + def _Call(self, t): + self.dispatch(t.func) + self.write("(") + comma = False + for e in t.args: + if comma: self.write(", ") + else: comma = True + self.dispatch(e) + for e in t.keywords: + if comma: self.write(", ") + else: comma = True + self.dispatch(e) + if sys.version_info[:2] < (3, 5): + if t.starargs: + raise SyntaxError('Invalid C++') + if t.kwargs: + raise SyntaxError('Invalid C++') + self.write(")") + + def _Subscript(self, t): + self.dispatch(t.value) + self.write("[") + self.dispatch(t.slice) + self.write("]") + + def _Starred(self, t): + raise SyntaxError('Invalid C++') + + # slice + def _Ellipsis(self, t): + self.write("...") + + def _Index(self, t): + self.dispatch(t.value) + + def _Slice(self, t): + if t.lower: + self.dispatch(t.lower) + self.write(":") + if t.upper: + self.dispatch(t.upper) + if t.step: + self.write(":") + self.dispatch(t.step) + + def _ExtSlice(self, t): + interleave(lambda: self.write(', '), self.dispatch, t.dims) + + # argument + def _arg(self, t): + if t.annotation: + self.dispatch(t.annotation) + self.write(' ') + else: + self.write("auto ") + self.write(t.arg) + self.locals.define(t.arg, t.lineno, self._indent) + + # others + def _arguments(self, t): + first = True + # normal arguments + defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults + for a, d in zip(t.args, defaults): + if first: first = False + else: self.write(", ") + + # ast.arg does not exist in python2 + if six.PY2: + self.write("auto ") + self.locals.define(a.id, a.lineno, self._indent) + + self.dispatch(a) + if d: + self.write("=") + self.dispatch(d) + + # varargs, or bare '*' if no varargs but keyword-only arguments present + if t.vararg or getattr(t, "kwonlyargs", False): + raise SyntaxError('Invalid C++') + + # keyword-only arguments + if getattr(t, "kwonlyargs", False): + raise SyntaxError('Invalid C++') + + # kwargs + if t.kwarg: + raise SyntaxError('Invalid C++') + + def _keyword(self, t): + raise SyntaxError('Invalid C++') + + def _Lambda(self, t): + self.write("(") + self.write("[] (") + self.dispatch(t.args) + self.write(") { return ") + self.dispatch(t.body) + self.write("; } )") + + def _alias(self, t): + self.write('using ') + self.write(t.name) + if t.asname: + self.write(" = " + t.asname) + self.write(';') + + def _withitem(self, t): + raise SyntaxError('Invalid C++') + + def _Await(self, t): + raise SyntaxError('Invalid C++') + + +def cppunparse(node, expr_semicolon=True): + strio = StringIO() + CPPUnparser(node, 0, CPPLocals(), strio, expr_semicolon=expr_semicolon) + return strio.getvalue().strip() + + +# Code can either be a string or a function +def py2cpp(code, expr_semicolon=True): + if isinstance(code, str): + return cppunparse(ast.parse(code), expr_semicolon) + elif code.__class__.__name__ == 'function': + try: + code_str = inspect.getsource(code) + + # Remove leading indentation + lines = code_str.splitlines() + leading_spaces = len(lines[0]) - len(lines[0].lstrip()) + code_str = '' + for line in lines: + code_str += line[leading_spaces:] + '\n' + + except: # Can be different exceptions coming from Python's AST module + raise TypeError('Invalid function given') + return cppunparse(ast.parse(code_str), expr_semicolon) + + else: + raise TypeError('Unsupported type for py2cpp') + + +def pyexpr2cpp(expr): + return py2cpp(expr, expr_semicolon=False) diff --git a/dace/codegen/instrumentation/__init__.py 
b/dace/codegen/instrumentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dace/codegen/instrumentation/perfsettings.py b/dace/codegen/instrumentation/perfsettings.py new file mode 100644 index 0000000000..00982de3d3 --- /dev/null +++ b/dace/codegen/instrumentation/perfsettings.py @@ -0,0 +1,1587 @@ +from dace.graph.nodes import MapEntry, MapExit, Tasklet +from dace.graph.graph import SubgraphView +from dace.memlet import Memlet +from dace.data import Array + +from dace.config import Config + +from dace.types import ScheduleType + +import re + +import sympy as sp + +# Helper function to get the module path +if __name__ == "__main__": + import os + print("path: " + os.path.dirname(__file__)) + + +class PerfSettings(object): + + _unique_counter = 0 + + _perf_enable_instrumentation = True + perf_enable_override_config = True + + #default_papi_counters = ["PAPI_TOT_INS", "PAPI_TOT_CYC", "PAPI_L1_TCM", "PAPI_L2_TCM", "PAPI_L3_TCM"] + default_papi_counters = [ + "PAPI_TOT_INS", "PAPI_TOT_CYC", "PAPI_L2_TCM", "PAPI_L3_TCM" + ] + + @staticmethod + def get_unique_number(): + ret = PerfSettings._unique_counter + PerfSettings._unique_counter = PerfSettings._unique_counter + 1 + return ret + + @staticmethod + def perf_multirun_num(): + """ Amount of iterations with different PAPI configurations to run. (1 means no multirun) """ + if not PerfSettings.perf_enable_instrumentation(): + return 1 + return 4 + + @staticmethod + def perf_multirun_options(): + """ Specifies the options for "multirunning": running the same program + multiple times with different performance counters. """ + ret = [] + + if PerfSettings.perf_multirun_num() == 1: + return ret # Don't specify these options by default + + for i in range(0, 4): + ret.append(("omp_num_threads", i + 1)) + return ret + + @staticmethod + def perf_default_papi_counters(): + return eval(Config.get("instrumentation", "default_papi_counters")) + + @staticmethod + def perf_enable_instrumentation(): + return Config.get_bool("instrumentation", "enable_papi") + + @staticmethod + def perf_enable_instrumentation_for(sdfg, node=None): + return PerfSettings.perf_enable_instrumentation( + ) and not sdfg.has_instrumented_parent() + + @staticmethod + def perf_supersection_emission_debug(): + return True + + @staticmethod + def perf_enable_counter_sanity_check(): + return Config.get_bool("instrumentation", + "enable_papi_counter_sanity_check") + + @staticmethod + def perf_print_instrumentation_output(): + return False + + @staticmethod + def perf_enable_vectorization_analysis(): + return Config.get_bool("instrumentation", + "enable_vectorization_analysis") + + @staticmethod + def perf_max_scope_depth(): + # This variable selects the maximum depth inside a scope. For example, + # "map { map {}}" with max_scope_depth 0 will result in + # "map { profile(map{}) }", while max_scope_depth >= 1 result in + # "map { map { profile() }}" + return Config.get("instrumentation", "max_scope_depth") + + perf_debug_profile_innermost = False # innermost = False implies outermost + perf_debug_annotate_scopes = True + perf_debug_annotate_memlets = False + perf_debug_hard_error = False # If set to true, untreated cases cause program abort. 
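+ # (Usage sketch, hedged; instrument() is a hypothetical stand-in for the
+ # actual instrumentation code. A pass would typically gate itself on the
+ # accessors above and on perf_whitelist_schedules, defined just below:
+ #   if PerfSettings.perf_enable_instrumentation() and \
+ #      entry.map.schedule in PerfSettings.perf_whitelist_schedules:
+ #       instrument(entry)
+ # )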
+
+ #TODO: There should be a variable per map element that overrides the scope depth
+ perf_tasklets = False
+
+ perf_whitelist_schedules = [
+ ScheduleType.Default, ScheduleType.CPU_Multicore,
+ ScheduleType.Sequential
+ ]
+
+
+ class PerfUtils(object):
+ @staticmethod
+ def unified_id(node_id, state_id):
+ if node_id > 0x0FFFF:
+ raise ValueError("Node ID is too large to fit in 16 bits!")
+ if state_id > 0x0FFFF:
+ raise ValueError("State ID is too large to fit in 16 bits!")
+ return (int(state_id) << 16) | int(node_id)
+
+ @staticmethod
+ def gather_remote_metrics():
+ """ Runs the membench benchmark on the remote host and returns the
+ measured memory bandwidth in bytes per cycle. """
+
+ # Run the tools/membench file on remote.
+ remote_workdir = Config.get("execution", "general", "workdir")
+ from diode.remote_execution import Executor
+ from string import Template
+ import subprocess
+ executor = Executor(None, True, None)
+
+ remote_filepath = remote_workdir + "/" + "membench.cpp"
+
+ executor.copy_file_to_remote("tools/membench.cpp", remote_filepath)
+
+ libs = Config.get("compiler", "cpu", "libs").split(" ")
+
+ libflags = map(lambda x: "-l" + x, libs)
+
+ libflagstring = " ".join(libflags)
+
+ path_resolve_command = "python3 -m dace.codegen.instrumentation.perfsettings"
+ # Resolve the DaCe module path on the remote host
+ s = Template(Config.get("execution", "general", "execcmd"))
+ cmd = s.substitute(
+ host=Config.get("execution", "general", "host"),
+ command=path_resolve_command)
+
+ p = subprocess.Popen(
+ cmd,
+ shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True)
+
+ stdout, _ = p.communicate(timeout=60)
+
+ remote_dace_path = re.search(r"path: (?P<dace_path>.*)", str(stdout))
+ if remote_dace_path:
+ remote_dace_path = remote_dace_path['dace_path']
+ print("Remote dace path: %s" % remote_dace_path)
+
+ # Now create the include path from that
+ include_path = "\"" + remote_dace_path + "/" + "runtime/include" + "\""
+
+ print("remote_workdir: " + remote_workdir)
+ compile_and_run_command = "cd " + remote_workdir + " && " + " pwd && " + Config.get(
+ "compiler", "cpu", "executable"
+ ) + " " + Config.get(
+ "compiler", "cpu", "args"
+ ) + " " + "-fopenmp" + " " + Config.get(
+ "compiler", "cpu", "additional_args"
+ ) + " -I" + include_path + " " + "membench.cpp -o membench" + " " + libflagstring + " && " + "./membench"
+
+ # Wrap that into a custom shell because ssh will not keep context.
+ # The HEREDOC is needed because we already use " and ' inside the command.
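+ # (Illustrative sketch, hedged: assuming the usual execcmd template
+ # "ssh $host $command", the wrapped command built below reaches the
+ # remote login shell through a heredoc on ssh's standard input, roughly:
+ #   ssh <host> << EOF
+ #   sh -c 'cd <workdir> && ... && ./membench'
+ #   EOF
+ # which keeps the inner quoting intact.)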
+ compile_and_run_command = "<< EOF\nsh -c '" + compile_and_run_command + "'" + "\nEOF" + + print("Compile command is " + compile_and_run_command) + + # run this command + s = Template(Config.get("execution", "general", "execcmd")) + cmd = s.substitute( + host=Config.get("execution", "general", "host"), + command=compile_and_run_command) + + p2 = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + + stdout2, _ = p2.communicate(timeout=60) + + #print("stdout2: " + str(stdout2)) + + bytes_per_cycle = re.search(r"result: (?P.*?$)", + str(stdout2)) + if bytes_per_cycle: + bytes_per_cycle = bytes_per_cycle['bytes_per_cycle'] + print("Bytes per cycle: %s" % bytes_per_cycle) + + executor.remote_delete_file(remote_workdir + "/membench.cpp") + executor.remote_delete_file(remote_workdir + "/membench") + + return bytes_per_cycle + + @staticmethod + def reduce_iteration_count(begin, end, step, retparams: dict): + + from dace.symbolic import symbols_in_sympy_expr, SymExpr + + # There are different rules when expanding depending on where the expand should happen + start_syms = symbols_in_sympy_expr(begin) + end_syms = symbols_in_sympy_expr(end) + step_syms = symbols_in_sympy_expr(step) + + def intersection(lista, listb): + return [x for x in lista if x in listb] + + start_dyn_syms = intersection(start_syms, retparams.keys()) + end_dyn_syms = intersection(end_syms, retparams.keys()) + step_dyn_syms = intersection(step_syms, retparams.keys()) + + def replace_func(element, dyn_syms, retparams): + print("Dynamic element symbols symbols: %s (out of %s)!" % + (str(element), str(dyn_syms))) + print("(srepr): " + sp.srepr(element)) + # Resolve all symbols using the retparams-dict + + for x in dyn_syms: + print("Replacing " + str(x)) + target = sp.functions.Min( + retparams[x] * (retparams[x] - 1) / 2, 0) + print("\twith target " + str(target)) + bstr = str(element) + #print(bstr) + element = sp.sympify(bstr, sp.abc._clash) + #print("\t(new srepr): " + sp.srepr(element)) + element = element.subs( + x, target) # Add the classic sum formula; going upwards + + # To not have hidden elements that get added again later, we also replace the values in the other itvars... + for k, v in retparams.items(): + newv = sp.sympify(str(v), sp.abc._clash) + + itsyms = symbols_in_sympy_expr(newv) + tarsyms = symbols_in_sympy_expr(target) + if x in map(str, tarsyms): + continue + # assert not x in itsyms # We never want to have the replaced symbol in its own expression. This can happen when applying 2 SMs + + tmp = newv.subs(x, target) + if tmp != v: + print("Replacing %s with %s" % (str(newv), str(tmp))) + retparams[k] = tmp + + print("\t New element: " + str(element)) + return element + + if len(start_dyn_syms) > 0: + pass + begin = replace_func(begin, start_dyn_syms, retparams) + + if len(end_dyn_syms) > 0: + pass + end = replace_func(end, end_dyn_syms, retparams) + + if len(step_dyn_syms) > 0: + pass + print("Dynamic step symbols %s!" % str(step)) + raise NotImplementedError + + return (begin, end, step) + + @staticmethod + def get_iteration_count(mapEntry: MapEntry, vars: dict): + """ Get the number of iterations for this map, allowing other variables as bounds. 
""" + from dace.symbolic import symbols_in_sympy_expr, SymExpr + + _map = mapEntry.map + _it = _map.params + + retparams = dict() + for k, v in vars.items(): + retparams[k] = v + + #print("Params: " + str(_it)) + for i, r in enumerate(_map.range): + begin, end, step = r + + end = end + 1 # end is inclusive, but we want it exclusive + + if isinstance(begin, SymExpr): + begin = begin.expr + if isinstance(end, SymExpr): + end = end.expr + if isinstance(step, SymExpr): + step = step.expr + + begin, end, step = PerfUtils.reduce_iteration_count( + begin, end, step, retparams) + num = (end - begin) / step # The count of iterations + retparams[_it[i]] = num + + return retparams + + @staticmethod + def all_maps(mapEntry: MapEntry, dfg: SubgraphView): + children = [ + x for x in dfg.scope_dict(True)[mapEntry] + if isinstance(x, MapEntry) + ] + + sub = [] + for x in children: + sub.extend(PerfUtils.all_maps(x, dfg)) + + children.extend(sub) + #children.extend([PerfUtils.all_maps(x, dfg) for x in children]) + return children + + @staticmethod + def map_depth(mapEntry: MapEntry): + # Returns the depth of this entry node. + # For now, the depth is stored inside the MapEntry node. + return mapEntry._map_depth + + @staticmethod + def set_map_depth(mapEntry: MapEntry, DFG: SubgraphView): + from dace.graph.nodes import Reduce, AccessNode, NestedSDFG + + # Set the depth for the mapEntry + + # We do not use mapEntry for now, but it might be required for different implementations + + # Get the sorted graph + dfg_sorted = DFG.topological_sort() + depth = 0 + following_nodes_invalid = False # Set to True when a fencing map is encountered + invalid_scope = -1 + invalid_index = PerfSettings.perf_max_scope_depth() + 1 + # Iterate and get the depth for every node, breaking when the specified node has been found + for e in dfg_sorted: + # Set the depth for every node on the way + if isinstance(e, MapEntry): + if not following_nodes_invalid and not e.map.schedule in PerfSettings.perf_whitelist_schedules: + print( + "Cannot instrument node %s, as it is running on a GPU (schedule %s)" + % (str(mapEntry), e.map.schedule)) + following_nodes_invalid = True # Invalidate all following maps + invalid_scope = depth + 1 # Mark this depth as invalid. 
Once the depth drops below this threshold, the invalid-mark will be removed + if following_nodes_invalid and depth: + e._map_depth = invalid_index # Set an invalid index (this will never be instrumented) + else: + e._map_depth = max(e._map_depth, depth) + if e.fence_instrumentation: + following_nodes_invalid = True # After a fence there must not be any instrumentation happening + + depth += 1 + elif isinstance(e, MapExit): + depth -= 1 + if depth < invalid_scope: + invalid_scope = -1 + following_nodes_invalid = False + elif isinstance(e, NestedSDFG): + e.sdfg.set_instrumented_parent() + #depth += 1 # Not sure if we should add a depth here + + pass + else: + if isinstance(e, Reduce): + pass + elif isinstance(e, AccessNode): + pass + elif isinstance(e, Tasklet): + pass + else: + print("Error-Type: " + type(e).__name__) + assert False + + @staticmethod + def is_deepest_node(check: MapEntry, DFG: SubgraphView): + nodes = DFG.nodes() + checkdepth = PerfUtils.map_depth(check) + return all( + not isinstance(x, MapEntry) or PerfUtils.map_depth(x) <= checkdepth + for x in nodes) + + @staticmethod + def instrument_entry(mapEntry: MapEntry, DFG: SubgraphView): + depth = PerfUtils.map_depth(mapEntry) + cond1 = PerfSettings.perf_enable_instrumentation( + ) and depth <= PerfSettings.perf_max_scope_depth() and ( + PerfUtils.is_deepest_node(mapEntry, DFG) + or depth == PerfSettings.perf_max_scope_depth()) + cond2 = mapEntry.map.schedule in PerfSettings.perf_whitelist_schedules + cond3 = not mapEntry.fence_instrumentation + if not cond2: + print("Cannot instrument node %s, as it is running on a GPU" % + str(mapEntry)) + return cond1 and cond2 and cond3 + + @staticmethod + def has_surrounding_perfcounters(node, DFG: SubgraphView): + """ Returns true if there is a possibility that this node is part of a + section that is profiled. """ + parent = DFG.scope_dict()[node] + + if isinstance(parent, MapEntry): + if parent.map._has_papi_counters or PerfUtils.map_depth( + parent) > PerfSettings.perf_max_scope_depth(): + return True + + return False + + @staticmethod + def get_memlet_byte_size(sdfg, memlet: Memlet): + pass + memdata = sdfg.arrays[memlet.data] + # For now, deal with arrays only + if isinstance(memdata, Array): + elems = [str(memdata.dtype.bytes)] + # The following for-loop is not relevant here, it just describes the shape of the source... + #for x in memdata.shape: + # elems.append(str(x)) + try: + if (memlet.num_accesses >= 0): + elems.append( + str(memlet.num_accesses) + ) # num_accesses seems to be the amount of accesses per tasklet execution + else: + print( + "Refusing to add negative accesses (%d) in get_memlet_byte_size!" 
+ % memlet.num_accesses) + except: + print("Unsupported memlet.num_accesses type, %s (%s)" % (str( + type(memlet.num_accesses)), str(memlet.num_accesses))) + + return "(" + "*".join(elems) + ")" + + else: + print("Untreated data type: ", type(memdata).__name__) + if PerfSettings.perf_debug_hard_error: + assert False + else: + return "0" + + @staticmethod + def get_out_memlet_costs(sdfg, state_id, node, dfg): + from dace.graph import nodes + from dace.sdfg import ScopeSubgraphView, SDFG, scope_contains_scope + scope_dict = sdfg.nodes()[state_id].scope_dict() + + out_costs = 0 + for edge in dfg.out_edges(node): + _, uconn, v, _, memlet = edge + dst_node = dfg.memlet_path(edge)[-1].dst + + # Target is neither a data nor a tasklet node + if (isinstance(node, nodes.AccessNode) + and (not isinstance(dst_node, nodes.AccessNode) + and not isinstance(dst_node, nodes.CodeNode))): + continue + + # Skip array->code (will be handled as a tasklet input) + if isinstance(node, nodes.AccessNode) and isinstance( + v, nodes.CodeNode): + continue + + # code->code (e.g., tasklet to tasklet) + if isinstance(v, nodes.CodeNode): + shared_data_name = 's%d_n%d%s_n%d%s' % ( + state_id, dfg.node_id(edge.src), edge.src_conn, + dfg.node_id(edge.dst), edge.dst_conn) + #result.write('__%s = %s;' % (shared_data_name, edge.src_conn), + # sdfg, state_id, [edge.src, edge.dst]) + # TODO: Check how to deal with this... + #raise NotImplementedError + continue + + # If the memlet is not pointing to a data node (e.g. tasklet), then + # the tasklet will take care of the copy + if not isinstance(dst_node, nodes.AccessNode): + continue + # If the memlet is pointing into an array in an inner scope, then the + # inner scope (i.e., the output array) must handle it + if (scope_dict[node] != scope_dict[dst_node] + and scope_contains_scope(scope_dict, node, dst_node)): + continue + + # Array to tasklet (path longer than 1, handled at tasklet entry) + if node == dst_node: + continue + + # Tasklet -> array + if isinstance(node, nodes.CodeNode): + if not uconn: + print("This would normally raise a syntax error!") + return 0 # We don't error-out because the error will be raised later + + try: + positive_accesses = bool(memlet.num_accesses >= 0) + except TypeError: + positive_accesses = False + + if memlet.subset.data_dims() == 0 and positive_accesses: + + if memlet.wcr is not None: + # write_and_resolve + # We have to assume that every reduction costs 3 accesses of the same size + out_costs += 3 * sp.sympify( + PerfUtils.get_memlet_byte_size(sdfg, memlet), + sp.abc._clash) + else: + #'%s.write(%s);\n' + # This standard operation is already counted + out_costs += sp.sympify( + PerfUtils.get_memlet_byte_size(sdfg, memlet), + sp.abc._clash) + # Dispatch array-to-array outgoing copies here + elif isinstance(node, nodes.AccessNode): + pass + return out_costs + + @staticmethod + def get_tasklet_byte_accesses(tasklet: Tasklet, dfg: SubgraphView, sdfg, + state_id): + """ Get the amount of bytes processed by `tasklet`. 
The formula is + sum(inedges * size) + sum(outedges * size) """ + in_accum = [] + out_accum = [] + in_edges = dfg.in_edges(tasklet) + out_edges = dfg.out_edges(tasklet) + + for ie in in_edges: + # type ie.data == Memlet + # type ie.data.data == Data + in_accum.append(PerfUtils.get_memlet_byte_size(sdfg, ie.data)) + + out_accum.append( + str(PerfUtils.get_out_memlet_costs(sdfg, state_id, tasklet, dfg))) + + # Merge (kept split to be able to change the behavior easily) + full = in_accum + full.extend(out_accum) + + return "(" + "+".join(full) + ")" + + @staticmethod + def get_map_exit_byte_accesses(mapexit: MapExit, dfg: SubgraphView, sdfg, + state_id): + """ Get the amount of bytes processed by mapexit. The formula is + sum(inedges * size) + sum(outedges * size) """ + in_accum = [] + out_accum = [] + in_edges = dfg.in_edges(mapexit) + out_edges = dfg.out_edges(mapexit) + + out_connectors = mapexit.out_connectors + + for ie in in_edges: + # type ie.data == Memlet + # type ie.data.data == Data + in_accum.append(PerfUtils.get_memlet_byte_size(sdfg, ie.data)) + + for oe in out_edges: + out_accum.append(PerfUtils.get_memlet_byte_size(sdfg, oe.data)) + + # Merge (kept split to be able to change the behavior easily) + full = in_accum + full.extend(out_accum) + + return "(" + "+".join(full) + ")" + + @staticmethod + def get_parents(outermost_node, node, sdfg, state_id): + + parent = None + # Because dfg is only a subgraph view, it does not contain the entry + # node for a given entry. This O(n) solution is suboptimal + for state in sdfg.nodes(): + s_d = state.scope_dict(node_to_children=False) + try: + scope = s_d[node] + except KeyError as e: + continue + + if (scope != None): + parent = scope + break + if (parent == None): + return [] + if (parent == outermost_node): + return [parent] + + return PerfUtils.get_parents(outermost_node, parent, sdfg, + state_id) + [parent] + + @staticmethod + def accumulate_byte_movements_v2(outermost_node, node, dfg: SubgraphView, + sdfg, state_id): + + itvars = dict() # initialize an empty dict + + # First, get a list of children + if isinstance(node, MapEntry): + children = dfg.scope_dict(node_to_children=True)[node] + else: + children = [] + assert not (node in children) + + # If there still are children, descend recursively (dfs is fine here) + if len(children) > 0: + size = 0 + for x in children: + size = size + PerfUtils.accumulate_byte_movements_v2( + outermost_node, x, dfg, sdfg, state_id) + + return size + else: + if isinstance(node, MapExit): + return 0 # We can ignore this. + + # If we reached the deepest node, get all parents + parent_list = PerfUtils.get_parents(outermost_node, node, sdfg, + state_id) + #print("Parents are " + str(parent_list)) + if isinstance(node, MapEntry): + map_list = parent_list + [node] + else: + #print("node is of type " + type(node).__name__) + map_list = parent_list + + # From all iterations, get the iteration count, replacing inner + # iteration variables with the next outer variables. + for x in map_list: + itvars = PerfUtils.get_iteration_count(x, itvars) + + #print("itvars: " + str(itvars)) + + itcount = 1 + for x in itvars.values(): + itcount = itcount * x + #print("Probable itcount: " + str(itcount)) + + #print("constants: " + str(sdfg.constants)) + + if isinstance(node, MapEntry): + raise ValueError( + "Unexpected node" + ) # A map entry should never be the innermost node + elif isinstance(node, MapExit): + return 0 # We can ignore this. 
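+ # (Illustrative example, hedged: for a nest map[i=0:N]{ map[j=0:M]{ t }},
+ # get_iteration_count accumulates {i: N, j: M}, itcount becomes N*M,
+ # and the Tasklet branch below therefore reports N*M times the bytes
+ # that t touches per execution.)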
+ elif isinstance(node, Tasklet): + return itcount * sp.sympify( + PerfUtils.get_tasklet_byte_accesses( + node, dfg, sdfg, state_id)) + else: + if PerfSettings.perf_debug_hard_error: + raise NotImplementedError + else: + return 0 + + @staticmethod + def accumulate_byte_movements(node, dfg: SubgraphView, sym2cpp, sdfg, + state_id): + """ Loops over all sub-iterations and calculates the number of bytes + moved (logically). """ + + # The coefficient consists of multipliers (i.e. maps) and bytes (i.e. + # memlet/tasklet movements) + coeff_this_node = "" + + if isinstance(node, MapEntry): + # get the iteration count for this entry + coeff_this_node = '*'.join([ + '((%s - %s) / %s)' % (sym2cpp(re + 1), sym2cpp(rb), + sym2cpp(rs)) + for rb, re, rs in node.map.range + ]) + + # Create a list to contain all suboperations (for this scope) + subops = [coeff_this_node] + + for edge in dfg.edges(): + source = dfg.scope_dict()[edge.src] + destination = dfg.scope_dict()[edge.dst] + if source == node and edge.dst != node: + subops.append( + PerfUtils.accumulate_byte_movements( + edge.dst, dfg, sym2cpp, sdfg, state_id)) + if destination == node and edge.src != node: + subops.append( + PerfUtils.accumulate_byte_movements( + edge.src, dfg, sym2cpp, sdfg, state_id)) + + # We can just simplify that directly + if any(x == "0" for x in subops): + return "0" + coeff_this_node = ' * '.join([x for x in subops if x != ""]) + return coeff_this_node + elif isinstance(node, MapExit): + # Ignore this type, we already dealt with it when we processed + # MapEntry + return "" + elif isinstance(node, Tasklet): + # Exact data movement costs depend on the tasklet code + return PerfUtils.get_tasklet_byte_accesses(node, dfg, sdfg, + state_id) + + else: + if PerfSettings.perf_debug_hard_error: + raise NotImplementedError + else: + return "0" + + class ParseStates: + CONTROL = 0 + VALUES = 1 + SECTION_SIZE = 2 + + class Entry: + def __init__(self): + pass + self.values = {} + self.nodeid = 0 + self.coreid = 0 + self.iteration = 0 + self.flags = 0 + + def is_valid(self): + return len(self.values) != 0 + + def add(self, counter, value): + self.values[counter] = value + + def get(self, name: str): + try: + return self.values[name] + except: + return None + + def toJSON(self): + return '{{ "node": "{node}",\n"thread": "{thread}",\n"iteration": "{iteration}",\n"flags": {flags},\n"values": [{values}]\n}}\n'.format( + node=str(self.nodeid), + thread=str(self.coreid), + iteration=str(self.iteration), + flags=str(self.flags), + values=", ".join([ + '{{ "{code}": {value} }}'.format( + code=str(code), value=str(value)) + for code, value in self.values.items() + ])) + + def toCSVsubstring(self, delim=','): + return delim.join([ + self.nodeid, self.coreid, self.iteration, + *self.values.values() + ]) # * == ... in other languages + + class Section: + def __init__(self, nodeid=0, threadid=0): + pass + self.entries = [] + self.nodeid = nodeid + self.datasize = 0 + self.bytes_moved = 0 + self.was_collapsed = False + self.threadid = threadid + + def is_complete(self): + """ Checks if all iterations are in this section. This might not + always be the case, e.g. in filtered sections. 
""" + itlist = [int(x.iteration) for x in self.entries] + sortitlist = sorted(itlist) + for i, e in enumerate(sortitlist): + if (i != int(e)): + print("list: %s\n" % sortitlist) + return False + return True + + def is_valid(self): + return len(self.entries) != 0 + + def add(self, e): + self.entries.append(e) + + def addSection(self, sec): + """ Merges another section into this section. """ + assert self.nodeid == sec.nodeid + + # We allow collapsing at most once. + if self.was_collapsed: + return + if sec.was_collapsed: + return + # Add all entries + for x in sec.entries: + self.add(x) + + # merge meta + #self.datasize += sec.datasize + self.bytes_moved += sec.bytes_moved + self.was_collapsed = True + sec.was_collapsed = True + + def select_event(self, event: str): + """ Selects all values of 'event' in correct order from all + entries. """ + return [ + int(x.get(event)) for x in self.entries if x.get(event) != None + ] + + def select_thread(self, thread: int): + """ Returns a section that only contains entries of `self` that + were obtained in the given thread. """ + ret = PerfUtils.Section(self.nodeid) + + for x in self.entries: + if int(x.coreid) == int(thread): + ret.entries.append(x) + + return ret + + def select_node(self, node: int): + """ Returns a section that only contains entries of `self` that + were obtained for the given node """ + ret = PerfUtils.Section(self.nodeid) + + for x in self.entries: + if int(x.nodeid) == int(node): + ret.entries.append(x) + + return ret + + def filter(self, predicate): + """ Returns a section that only contains entries `e` for which + `predicate(e)` returns true""" + ret = PerfUtils.Section(self.nodeid) + + for x in self.entries: + if predicate(x): + ret.entries.append(x) + + return ret + + def get_max_thread_num(self): + """ Returns the maximal thread number in at most O(n) + complexity. """ + max = 0 + for x in self.entries: + if int(x.coreid) > max: + max = int(x.coreid) + return max + + def toCSVsubstring(self, prepend="", delim=',', linedelim='\n'): + ret = "" + for x in self.entries: + ret += delim.join([ + prepend, "node" + self.nodeid, self.threadid, + x.toCSVsubstring(delim) + ]) + linedelim + return ret + + def toJSON(self): + return '{{ "entry_node": {entry_node}, "static_movement": {datasize}, "entry_core": {core}, "entries": ['.format( + entry_node=self.nodeid, + datasize=self.datasize, + core=self.threadid) + ", ".join( + [x.toJSON() for x in self.entries]) + "]}" + + class SuperSection: + """ Contains multiple Sections. + @see Section + """ + + def __init__(self, supernode=0): + self.sections = {} + self.supernode = supernode + + def is_valid(self): + return len(self.sections.values()) > 0 + + def addSection(self, section): + if int(section.threadid) in self.sections: + self.sections[int(section.threadid)].append(section) + else: + self.sections[int(section.threadid)] = [section] + + def addEntry(self, entry): + + if not entry.is_valid(): + # ignore invalid entries + return + + # We have 2 cases - either: + # (a) the section starts outside of a parallel block: + # Every entry needs to be assigned to this block. There will only + # be one block with threadid == 0 in this case. + # or (b) the section starts in a parallel block: + # Entries can be assigned by thread_id. 
+ if int(entry.coreid) in self.sections: + # Assign by thread id + try: + self.sections[int(entry.coreid)][-1].add(entry) + except: + print("Sections has keys " + str(self.sections.keys())) + raise + else: + # Ideally, we can only add nodes to a section if they have the + # same core id. However, in nested omp constructs, the + # lower-level sections are usually just run on core 0. + # So if a section starts on core 1, its entries might still + # report core 0. + try: + self.sections[0][-1].add(entry) + except Exception as e: + print("error, contained sections:") + print(str(self.sections)) + print(str(self.sections.values())) + + mitigated = False + # Find the section that matches by nodeid... + for x in self.sections.values(): + # Find the correct section and append to that + # (start with oldest entry) + for y in reversed(x): + if y.nodeid == entry.nodeid: + y.add(entry) + print( + "Warning: Mitigation successful, but you should probably enable OMP_NESTED" + ) + mitigated = True + break + + if not mitigated: # Only complain if we could not mitigate + raise e + + def getSections(self): + l = [] + for x in self.sections.values(): + l.extend(x) + return [x for x in l] + + def toCSVstring(self, delim=',', linedelim='\n'): + """ Create a CSV string from the data. """ + + # Squashes everything into a row, duplicating data. + ret = "" + for x in self.sections.values(): + for y in x: + ret += y.toCSVsubstring("supernode" + str(self.supernode), + delim, linedelim) + ret += "ENDSUPERSECTION" + linedelim + return ret + + def toJSON(self): + return '{{ "hint": "supersection", "supernode": {supernode},\n "sections": [{sections}] }}'.format( + supernode=self.supernode, + sections=",\n".join([x.toJSON() for x in self.getSections()])) + + @staticmethod + def perf_counter_store_string(counterlist: [str]): + """ Creates a performance counter typename string. """ + return "PAPIValueStore<" + ", ".join(counterlist) + ">" + + @staticmethod + def perf_counter_string_from_string_list(counterlist: [str]): + """ Creates a performance counter typename string. """ + if isinstance(counterlist, str): + print("Wrong format") + counterlist = eval(counterlist) + return "PAPIPerfLowLevel<" + ", ".join(counterlist) + ">" + + @staticmethod + def perf_counter_string(node): + """ Creates a performance counter typename string. 
""" + try: + assert isinstance(node.papi_counters, list) + return PerfUtils.perf_counter_string_from_string_list( + node.papi_counters) + except Exception as e: + return PerfUtils.perf_counter_string_from_string_list( + PerfSettings.perf_default_papi_counters()) + + @staticmethod + def read_available_perfcounters(): + from string import Template + import subprocess + + papi_avail_str = "papi_avail -a" + s = Template(Config.get("execution", "general", "execcmd")) + cmd = s.substitute( + host=Config.get("execution", "general", "host"), + command=papi_avail_str) + p = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + + stdout, _ = p.communicate(timeout=60) + + counter_num = re.search( + r"Number Hardware Counters[\s.]*:\s(?P[0-9]+)", + str(stdout)) + if counter_num: + counter_num = int(counter_num['num_cntr']) + print("Hardware counters: %s" % counter_num) + + print("PAPI preset events:") + # Find non-derived events first + non_derived = re.findall( + r"(?PPAPI_[0-9A-Z_]+)\s+0x[0-9a-zA-Z]+\s+No", + str(stdout)) + print("Non-Derived: ", non_derived) + + # Now all derived events + derived = re.findall( + r"(?PPAPI_[0-9A-Z_]+)\s+0x[0-9a-zA-Z]+\s+Yes", + str(stdout)) + print("Derived: ", derived) + + return (non_derived, derived, counter_num) + + @staticmethod + def collapse_sections(sections: list): + """ Combine sections with the same ID into one single section. """ + + seen = [] # Nodeids that were already collapsed + collapsed = [ + ] # The return value, consisting of all collapsed sections + + # Add all elements that were already collapsed + collapsed = [x for x in sections if x.was_collapsed] + + print("%d sections were already collapsed" % len(collapsed)) + + for _ in sections: + preselection = [ + x for x in sections + if not (x.nodeid, x.threadid) in seen and not x.was_collapsed + ] + if preselection == []: + break + target = preselection[0] + seen.append((target.nodeid, target.threadid)) + selection = [ + x for x in sections + if x.nodeid == target.nodeid and x.threadid == target.threadid + and x != target and not x.was_collapsed + ] + for y in selection: + target.addSection(y) + collapsed.append(target) + + target.was_collapsed = True # If selection is [] + + assert target.was_collapsed + + # Debug + removed_nodes = [x for x in sections if not (x in collapsed)] + print("Removed nodes: " + str([x.toJSON() for x in removed_nodes])) + print( + "Reduced from %d sections to %d" % (len(sections), len(collapsed))) + return collapsed + + @staticmethod + def print_instrumentation_output(data: str): + import json + print("print_instrumentation_output start") + # Regex for Section start + bytes: # Section start \(node (?P[0-9]+)\)\nbytes: (?P[0-9]+) + # Regex for general entries: # entry \((?P[0-9]+), (?P[0-9]+), (?P[0-9]+), (?P[0-9]+)\)\n((?P[0-9-]+): (?P[0-9-]+)\n)* + + print_values = False + + multirun_results = [] + multirun_supersections = [] + current_multirun_line = "" + sections = [] + supersection_node_id = None + supersections = [] + current_supersection = PerfUtils.SuperSection() + current_section = PerfUtils.Section() + current_entry = PerfUtils.Entry() + + state = PerfUtils.ParseStates.CONTROL + if isinstance(data, str): + lines = data.split('\n') + is_string_input = True + else: + lines = data + is_string_input = False + + line_num = 0 + for line in lines: + line_num = line_num + 1 + if not is_string_input: + line = line[:-1] # Chomp trailing newline + + if "multirun" in line: + # Multirun result + + try: + 
current_supersection.addEntry(current_entry) + except Exception as e: + print("Error occurred in line " + str(line_num) + "!") + raise e + + if current_section.is_valid(): + pass + + # Reset variables + current_section = PerfUtils.Section() + current_entry = PerfUtils.Entry() + + sections.extend(current_supersection.getSections()) + supersections.append(current_supersection) + + current_supersection = PerfUtils.SuperSection() + + if current_multirun_line != "" and sections != []: + multirun_results.append((current_multirun_line.replace( + "\n", ""), sections)) + if current_multirun_line != "" and supersections != []: + multirun_supersections.append( + (current_multirun_line.replace("\n", ""), + supersections)) + + current_multirun_line = line + sections = [] + supersections = [] + continue + if len(line) == 0: + continue + if line[0] == '#': + state = PerfUtils.ParseStates.CONTROL + if state == PerfUtils.ParseStates.CONTROL: + # First try: Entry + match = re.search( + r"# entry \((?P[0-9]+), (?P[0-9]+), (?P[0-9]+), (?P[0-9]+)\)", + line) + if match: + d = match.groupdict() + + try: + current_supersection.addEntry(current_entry) + except Exception as e: + print("Error occurred in line " + str(line_num) + "!") + raise e + + current_entry = PerfUtils.Entry() + + current_entry.nodeid = d['entry_node'] + current_entry.coreid = d['entry_thread'] + current_entry.iteration = d['entry_iteration'] + current_entry.flags = d['entry_flags'] + state = PerfUtils.ParseStates.VALUES + continue + + # Next try: Section header + match = re.search( + r"# Section start \(node (?P[0-9]+), core (?P[0-9]+)\)", + line) + if match: + #print("Matched Section Start") + d = match.groupdict() + + try: + current_supersection.addEntry(current_entry) + except Exception as e: + print("Error occurred in line " + str(line_num) + "!") + raise e + + current_entry = PerfUtils.Entry() + if (current_section.is_valid()): + #sections.append(current_section) + pass + current_section = PerfUtils.Section( + d['section_start_node'], d['section_start_core']) + current_supersection.addSection(current_section) + state = PerfUtils.ParseStates.SECTION_SIZE + continue + # Next try: Supersection header + match = re.search( + r"# Supersection start \(node (?P[0-9]+)\)", + line) + if match: + d = match.groupdict() + + supersection_node_id = d['section_start_node'] + + try: + current_supersection.addEntry(current_entry) + except Exception as e: + print("Error occurred in line " + str(line_num) + "!") + raise e + current_entry = PerfUtils.Entry() + + if (current_section.is_valid()): + #sections.append(current_section) + pass + + sections.extend(current_supersection.getSections()) + + supersections.append(current_supersection) + current_supersection = PerfUtils.SuperSection( + d['section_start_node']) + + current_section = PerfUtils.Section() # Clear the record + + state = PerfUtils.ParseStates.CONTROL + continue + # Next try: Section data moved + match = re.search(r"# moved_bytes: (?P[0-9]+)", + line) + if match: + d = match.groupdict() + current_section.bytes_moved = d['moved_bytes'] + continue + # Next try: Section data moved + match = re.search(r"# contention: (?P[0-9]+)", + line) + if match: + d = match.groupdict() + if int(d['contention']) != 0: + print( + "Contention: {cont}".format(cont=d['contention'])) + continue + # Next try: Entry (anonymous) + # (Should not happen) + print("Error, unexpected: anonymous entry %s" % line) + print(str(match)) + elif state == PerfUtils.ParseStates.VALUES: + match = re.search(r"(?P[0-9-]+): (?P[0-9-]+)", + 
line) + if match: + #print("Matched Value") + d = match.groupdict() + current_entry.add(d['counter'], d['value']) + else: + print("Failed to match expected values!") + continue + elif state == PerfUtils.ParseStates.SECTION_SIZE: + match = re.search(r"bytes: (?P[0-9-]+)", line) + if match: + #print("Matched Section Size") + d = match.groupdict() + current_section.datasize = d['bytes'] + else: + pass + continue + + try: + current_supersection.addEntry(current_entry) + except Exception as e: + print("Error occurred in line " + str(line_num) + "!") + raise e + + if current_section.is_valid(): + #sections.append(current_section) + pass + + #sections = PerfUtils.collapse_sections(sections) + #sections.extend(PerfUtils.collapse_sections(current_supersection.getSections())) + sections.extend(current_supersection.getSections()) + supersections.append(current_supersection) + multirun_results.append((current_multirun_line, sections)) + multirun_supersections.append((current_multirun_line, supersections)) + + # We'll filter invalid supersections later... + + print("Multirun length: " + str(len(multirun_results))) + + for o, s in multirun_results: + print("\tSection size: " + str(len(s))) + print("\t\tSection size: " + str(s[0].datasize)) + + try: + totstr = '{ "type": "PerfInfo", "payload": [' + ", ".join([ + '{"runopts": "%s", "data": [%s]}' % (o, ", ".join( + [x.toJSON() for x in r_supersections if x.is_valid()])) + for o, r_supersections in multirun_supersections + ]) + "]}" + + #totstr = '{ "type": "PerfInfo", "payload": [' + ", ".join([x.toJSON() for x in sections]) + "]}" + with open("perf.json", "w") as out: + out.write(totstr) + + # Debug CSV output + for idx, v in enumerate(multirun_supersections): + o, r_supersections = v + with open("perf%d.csv" % idx, "w") as out: + for x in r_supersections: + out.write(x.toCSVstring()) + + except: + import traceback + print("[Error] Failed to jsonify") + print(traceback.format_exc()) + + # Check if this runs + try: + for s in sections: + json.loads(s.toJSON()) + except: + print("[Error] JSON contains syntax errors!") + + if print_values: + print("==== ANALYSIS ====") + print("Got %d sections" % len(sections)) + for i, section in enumerate(sections): + print("Section %d (node %s)" % (i, section.nodeid)) + print("static memory movement (estimation): %s" % str( + section.datasize)) + print("runtime memory movement (measured): %s" % str( + section.bytes_moved)) + + max_thread_num = section.get_max_thread_num() + print("max_thread_num: %d" % max_thread_num) + tot_cyc = list() + tot_l3_miss = list() + tot_l2_miss = list() + for t in range(0, max_thread_num + 1): + ts = section.select_thread(t) + tc = ts.select_event('-2147483589') + # print("tc: %s\nsum(tc): %s" % (str(tc), str(sum(tc)))) + tot_cyc.append(sum(tc)) + + tl3 = ts.select_event('-2147483640') + tot_l3_miss.append(sum(tl3)) + + tl2 = ts.select_event('-2147483641') + tot_l2_miss.append(sum(tl2)) + + # Now we can get the balance + for i, t in enumerate(tot_cyc): + print("Thread %d took %d cycles" % (i, t)) + from statistics import stdev, mean + if len(tot_cyc) > 1 and mean(tot_cyc) != 0: + + print("stdev: %d" % stdev(tot_cyc)) + print("Balance: %f" % + (float(stdev(tot_cyc)) / float(mean(tot_cyc)))) + + for i, t in enumerate(tot_l3_miss): + print("Thread %d had %d L3 misses" % (i, t)) + sum_l3 = sum(tot_l3_miss) + print( + "%d bytes (presumably) accessed\n%d L3 misses over all threads\n%d bytes loaded from memory" + % (int(section.datasize), int(sum_l3), int(sum_l3) * 64)) + + for i, t in 
enumerate(tot_l2_miss): + print("Thread %d had %d L2 misses" % (i, t)) + sum_l2 = sum(tot_l2_miss) + print( + "%d bytes (presumably) accessed\n%d L2 misses over all threads\n%d bytes loaded from L3" + % (int(section.datasize), int(sum_l2), int(sum_l2) * 64)) + + +class PAPIUtil: + @staticmethod + def fallback_dict(available_events): + """ + Defines potential fallbacks for unavailable PAPI (preset) events + """ + d = dict() + #TCM => DCM + d['PAPI_L1_TCM'] = [ + x for x in ['PAPI_L1_DCM'] if x in available_events + ] + d['PAPI_L2_TCM'] = [ + x for x in ['PAPI_L2_DCM'] if x in available_events + ] + d['PAPI_L3_TCM'] = [ + x for x in ['PAPI_L3_DCM'] if x in available_events + ] + #DCM => TCM + d['PAPI_L1_DCM'] = [ + x for x in ['PAPI_L1_TCM'] if x in available_events + ] + d['PAPI_L2_DCM'] = [ + x for x in ['PAPI_L2_TCM'] if x in available_events + ] + d['PAPI_L3_DCM'] = [ + x for x in ['PAPI_L3_TCM'] if x in available_events + ] + + return d + + @staticmethod + def get_fallback(event, available_events): + """ + Returns a string identifying the most appropriate fallback for 'event', + or None if no such fallback exists. + """ + fbd = PAPIUtil.fallback_dict(available_events) + fb = fbd[event] + if (len(fb) == 0): + return None + else: + return fb[0] + + +class PerfMetaInfo: + """ Class dedicated to keep meta information about the generated code, in + particular line numbers. """ + + def __init__(self): + self.nodes = dict() # Maps nodes to their strings + self.lines = dict() # Maps nodes to their line number + + def add_node(self, node, string): + self.nodes[node] = string + + def has_node(self, node): + return node in self.nodes.keys() + + def resolve(self, codestr: str): + """ Maps all entries in self.node to line numbers """ + index = 0 + line = 1 + print("self.nodes: %s\ntype: %s" % (self.nodes, type(self.nodes))) + for key, value in self.nodes.items(): + pos = codestr.find(value, index) + if pos == -1: + # We will not accept this. This should only ever occur if some + # part of the program pretty-prints code. + assert False + sublines = codestr.count('\n', index, pos) + line += sublines + index = pos + # We store the current line back to self.lines + self.lines[key] = line + + def analyze(self, vectorizer_output: str): + """ Checks if a certain operation or a segment within a region of an + operation was vectorized. """ + # We only match calls originating from ./src/cpu/*, but it might still + # include some of the instrumentation. 
Consider running this on + # non-instrumented code instead + data = re.findall( + r".*?src/cpu/(?P[^:]*):(?P[\d]*):(?P[\d]*): (?P[^\n]*)", + vectorizer_output) + + print("data is:\n%s" % data) + + print("Node information is\n%s\n" % self.nodes) + print("Line information is\n%s\n" % self.lines) + + ret = dict( + ) # We return a dict of node -> [(file, line, col, Message)] + + first = True + tmp = (None, None) + for key, value in self.lines.items(): + # We now find for each key the value of their respective start + # (exception: MapExit, where the end counts) + # Then, we associate the message to that key + if not first: + prevkey, prevval = tmp + for file, line, col, message in data: + if int(prevval) <= int(line) and int(line) < int(value): + # Valid entry + if not (prevkey in ret.keys()): + ret[prevkey] = list() + ret[prevkey].append((file, line, col, message)) + else: + first = False + + tmp = (key, value) + + # For the last entry: + prevkey, prevval = tmp + if prevkey != None: + for file, line, col, message in data: + if int(prevval) <= int(line): + # Valid entry + if not (prevkey in ret.keys()): + ret[prevkey] = list() + ret[prevkey].append((file, line, col, message)) + + print("ret:\n%s" % ret) + + return ret + + +class PerfMetaInfoStatic: + info = PerfMetaInfo() + + +class PerfPAPIInfo: + """ Class used to keep information about the remote, most notably the + allowed configurations. """ + + def __init__(self): + self.num_hw_counters = -1 + self.preset_cost = dict() # event: str -> num_counters: int + self.cached_host = "" + self.memspeed = 20.0 # B/c + + def set_memspeed(self, speed): + self.memspeed = speed + + def load_info(self): + """ Load information about the counters from remote. """ + from string import Template + import subprocess + + print("Loading counter info from remote...") + + if self.cached_host == Config.get("execution", "general", "host"): + return # Do not run this every time, just the first time + else: + # else reset + self.num_hw_counters = -1 + self.preset_cost = dict() + + non_derived, derived, num_ctrs = PerfUtils.read_available_perfcounters( + ) + self.num_hw_counters = num_ctrs + + # Having these events, the non_derived (by definition) use 1 counter + for x in non_derived: + self.preset_cost[x] = 1 + + # For the others, we have to request some more information. + # NOTE: This could be moved into a shell script and run on remote + # if issuing many commands is too slow + for index, x in enumerate(derived): + print("%d/%d Elements...\r" % (index + 1, len(derived)), end='') + papi_avail_str = 'papi_avail -e %s | grep --color=never "Number of Native Events"' % x + s = Template(Config.get("execution", "general", "execcmd")) + cmd = s.substitute( + host=Config.get("execution", "general", "host"), + command=papi_avail_str) + p = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + + stdout, _ = p.communicate(timeout=60) + + counter_num_grp = re.search( + r"Number of Native Events:\s*(?P\d+)", str(stdout)) + if counter_num_grp != None: + self.preset_cost[x] = int(counter_num_grp['num']) + else: + print("\nError: Expected to find a number here...") + + self.cached_host = Config.get("execution", "general", "host") + print("\nDone") + + def check_counters(self, counter_lists: list): + """ Checks if the specified counter groups can be used. 
""" + assert self.cached_host != "" + + counter_lists_set = list() + + for x in counter_lists: + if not x in counter_lists_set: + counter_lists_set.append(x) + for counter_list in counter_lists_set: + sum_counters = 0 + for c in counter_list: + try: + sum_counters += self.preset_cost[c] + except: + # This should only happen with Native Events + print( + "check_counters failed with reason: Unknown/unsupported event code specified: %s" + % c) + return False + if sum_counters > self.num_hw_counters: + print( + "check_counters failed with reason: Not enough hardware counters to support specified events" + ) + return False + return True + + +class PerfPAPIInfoStatic: + info = PerfPAPIInfo() diff --git a/dace/codegen/prettycode.py b/dace/codegen/prettycode.py new file mode 100644 index 0000000000..f58e4a707b --- /dev/null +++ b/dace/codegen/prettycode.py @@ -0,0 +1,70 @@ +""" Code I/O stream that automates indentation and mapping of code to SDFG + nodes. """ + +from six import StringIO +from dace.config import Config + + +class CodeIOStream(StringIO): + """ Code I/O stream that automates indentation and mapping of code to SDFG + nodes. """ + + def __init__(self, base_indentation=0): + super(CodeIOStream, self).__init__() + self._indent = 0 + self._spaces = int(Config.get('compiler', 'indentation_spaces')) + + def write(self, contents, sdfg=None, state_id=None, node_id=None): + # Delete single trailing newline, as this will be implicitly inserted + # anyway + if contents: + if contents[-1] == "\n": + lines = contents[:-1].split("\n") + else: + lines = contents.split('\n') + else: + lines = contents + + # If SDFG/state/node location is given, annotate this line + if sdfg is not None: + location_identifier = ' ////__DACE:%s' % sdfg.name + if state_id is not None: + location_identifier += ':' + str(state_id) + if node_id is not None: + if not isinstance(node_id, list): + node_id = [node_id] + for i, nid in enumerate(node_id): + if not isinstance(nid, int): + node_id[i] = sdfg.nodes()[state_id].node_id(nid) + location_identifier += ':' + ','.join( + [str(nid) for nid in node_id]) + else: + location_identifier = '' + + # Write each line separately + for line in lines: + opening_braces = line.count('{') + closing_braces = line.count('}') + brace_balance = opening_braces - closing_braces + + # Write line and then change indentation + if brace_balance < 0: + self._indent += brace_balance + + codeline = self._indent * self._spaces * ' ' + line.strip() + + # Location identifier is written at character 81 and on, find out + # how many spaces we need to add for that + loc_spaces = max(80 - len(codeline), 2) + + super(CodeIOStream, self).write(codeline + loc_spaces * ' ' + + location_identifier + '\n') + if brace_balance > 0: + self._indent += brace_balance + + # If indentation failed, warn user + if self._indent < -1: + super(CodeIOStream, self).write( + '///WARNING: Indentation failure! 
This probably ' + + 'indicates an error in the SDFG.\n') + self._indent = 0 diff --git a/dace/codegen/targets/__init__.py b/dace/codegen/targets/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/dace/codegen/targets/__init__.py @@ -0,0 +1 @@ + diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py new file mode 100644 index 0000000000..65002bc74b --- /dev/null +++ b/dace/codegen/targets/cpu.py @@ -0,0 +1,2618 @@ +import ast +import copy +import functools +import itertools +import sympy as sp +from six import StringIO + +from dace.codegen import cppunparse + +import dace +from dace.config import Config +from dace.frontend import operations +from dace import data, subsets, symbolic, types, memlet as mmlt +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.codeobject import CodeObject +from dace.codegen.targets import framecode +from dace.codegen.targets.target import (TargetCodeGenerator, make_absolute, + DefinedType) +from dace.graph import nodes, nxutil +from dace.sdfg import ScopeSubgraphView, SDFG, scope_contains_scope, find_input_arraynode, find_output_arraynode, is_devicelevel + +from dace.frontend.python.astutils import ExtNodeTransformer, rname, unparse +from dace.properties import LambdaProperty + +from dace.codegen.instrumentation.perfsettings import PerfSettings, PerfUtils, PerfMetaInfo, PerfMetaInfoStatic + +_REDUCTION_TYPE_TO_OPENMP = { + types.ReductionType.Max: 'max', + types.ReductionType.Min: 'min', + types.ReductionType.Sum: '+', + types.ReductionType.Product: '*', + types.ReductionType.Bitwise_And: '&', + types.ReductionType.Logical_And: '&&', + types.ReductionType.Bitwise_Or: '|', + types.ReductionType.Logical_Or: '||', + types.ReductionType.Bitwise_Xor: '^', +} + + +class CPUCodeGen(TargetCodeGenerator): + """ SDFG CPU code generator. """ + + title = 'CPU' + target_name = 'cpu' + language = 'cpp' + + def __init__(self, frame_codegen, sdfg): + self._frame = frame_codegen + self._dispatcher = frame_codegen.dispatcher + dispatcher = self._dispatcher + + self._locals = cppunparse.CPPLocals() + # Scope depth (for use of the 'auto' keyword when + # defining locals) + self._ldepth = 0 + + # FIXME: this allows other code generators to change the CPU + # behavior to assume that arrays point to packed types, thus dividing + # all addresess by the vector length. 
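# --- Illustrative sketch (editorial annotation, not part of this patch) ---
# What "dividing addresses by the vector length" means for packed types:
# if a T* array is reinterpreted as a dace::vec<T, N>*, a scalar element
# offset must be divided by N to address the correct vector. The helper
# below is hypothetical and exists only for this example.

def packed_offset(element_offset: int, veclen: int) -> int:
    """Convert a scalar element offset into a packed-vector offset."""
    assert element_offset % veclen == 0, "offset must be vector-aligned"
    return element_offset // veclen

# Element 8 of a float array viewed as dace::vec<float, 4>* is vector 2.
assert packed_offset(8, 4) == 2
# --- End sketch ---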
+ self._packed_types = False + + # Keep track of traversed nodes + self._generated_nodes = set() + self._allocated_arrays = set() + # Keeps track of generated connectors, so we know how to access them in + # nested scopes + for name, arg_type in sdfg.arglist().items(): + if (isinstance(arg_type, dace.data.Scalar) + or isinstance(arg_type, dace.types.typeclass)): + self._dispatcher.defined_vars.add(name, DefinedType.Scalar) + elif isinstance(arg_type, dace.data.Array): + self._dispatcher.defined_vars.add(name, DefinedType.Pointer) + elif isinstance(arg_type, dace.data.Stream): + if arg_type.is_stream_array(): + self._dispatcher.defined_vars.add(name, + DefinedType.StreamArray) + else: + self._dispatcher.defined_vars.add(name, DefinedType.Stream) + else: + raise TypeError("Unrecognized argument type: {}".format( + type(arg_type).__name__)) + + # Register dispatchers + dispatcher.register_node_dispatcher(self) + dispatcher.register_map_dispatcher( + [types.ScheduleType.CPU_Multicore, types.ScheduleType.Sequential], + self) + + cpu_storage = [ + types.StorageType.CPU_Heap, types.StorageType.CPU_Pinned, + types.StorageType.CPU_Stack, types.StorageType.Register + ] + dispatcher.register_array_dispatcher(cpu_storage, self) + + # Register CPU copies (all internal pairs) + for src_storage, dst_storage in itertools.product( + cpu_storage, cpu_storage): + dispatcher.register_copy_dispatcher(src_storage, dst_storage, None, + self) + + @staticmethod + def cmake_options(): + compiler = make_absolute(Config.get("compiler", "cpu", "executable")) + flags = Config.get("compiler", "cpu", "args") + flags += Config.get("compiler", "cpu", "additional_args") + + # Args for vectorization output + if PerfSettings.perf_enable_vectorization_analysis(): + flags += " -fopt-info-vec-optimized-missed=vecreport.txt " + + options = [ + "-DCMAKE_CXX_COMPILER=\"{}\"".format(compiler), + "-DCMAKE_CXX_FLAGS=\"{}\"".format(flags), + ] + return options + + def get_generated_codeobjects(self): + # CPU target generates inline code + return [] + + @property + def has_initializer(self): + return False + + @property + def has_finalizer(self): + return False + + def generate_scope(self, sdfg: SDFG, dfg_scope: ScopeSubgraphView, + state_id, function_stream, callsite_stream): + entry_node = dfg_scope.source_nodes()[0] + presynchronize_streams(sdfg, dfg_scope, state_id, entry_node, + callsite_stream) + + self.generate_node(sdfg, dfg_scope, state_id, entry_node, + function_stream, callsite_stream) + self._dispatcher.dispatch_subgraph( + sdfg, + dfg_scope, + state_id, + function_stream, + callsite_stream, + skip_entry_node=True) + + def generate_node(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + # Dynamically obtain node generator according to class name + gen = getattr(self, '_generate_' + type(node).__name__) + + gen(sdfg, dfg, state_id, node, function_stream, callsite_stream) + + # Mark node as "generated" + self._generated_nodes.add(node) + + self._locals.clear_scope(self._ldepth + 1) + + def allocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + name = node.data + nodedesc = node.desc(sdfg) + if ((state_id, node.data) in self._allocated_arrays + or (None, node.data) in self._allocated_arrays + or nodedesc.transient == False): + return + self._allocated_arrays.add((state_id, node.data)) + + # Compute array size + arrsize = ' * '.join([sym2cpp(s) for s in nodedesc.strides]) + + if isinstance(nodedesc, data.Scalar): + callsite_stream.write("%s %s;\n" % (nodedesc.dtype.ctype, 
name), + sdfg, state_id, node) + self._dispatcher.defined_vars.add(name, DefinedType.Scalar) + elif isinstance(nodedesc, data.Stream): + ################################################################### + # Stream directly connected to an array + + if is_array_stream_view(sdfg, dfg, node): + if state_id is None: + raise SyntaxError( + 'Stream-view of array may not be defined ' + 'in more than one state') + + arrnode = sdfg.arrays[nodedesc.sink] + state = sdfg.nodes()[state_id] + edges = state.out_edges(node) + if len(edges) > 1: + raise NotImplementedError('Cannot handle streams writing ' + 'to multiple arrays.') + + memlet_path = state.memlet_path(edges[0]) + # Allocate the array before its stream view, if necessary + self.allocate_array(sdfg, dfg, state_id, memlet_path[-1].dst, + function_stream, callsite_stream) + + array_expr = self.copy_expr(sdfg, nodedesc.sink, edges[0].data) + threadlocal = '' + threadlocal_stores = [ + types.StorageType.CPU_Stack, types.StorageType.Register + ] + if (sdfg.arrays[nodedesc.sink].storage in threadlocal_stores + or nodedesc.storage in threadlocal_stores): + threadlocal = 'Threadlocal' + callsite_stream.write( + 'dace::ArrayStreamView%s<%s> %s (%s);\n' % + (threadlocal, arrnode.dtype.ctype, name, array_expr), sdfg, + state_id, node) + self._dispatcher.defined_vars.add(name, DefinedType.Stream) + return + + ################################################################### + # Regular stream + + dtype = "dace::vec<{}, {}>".format(nodedesc.dtype.ctype, + sym2cpp(nodedesc.veclen)) + + if nodedesc.buffer_size != 0: + definition = "dace::Stream<{}> {}({});".format( + dtype, name, nodedesc.buffer_size) + else: + definition = "dace::Stream<{}> {};".format(dtype, name) + + callsite_stream.write(definition, sdfg, state_id, node) + self._dispatcher.defined_vars.add(name, DefinedType.Stream) + + elif (nodedesc.storage == types.StorageType.CPU_Heap + or nodedesc.storage == types.StorageType.Immaterial + ): # TODO: immaterial arrays should not allocate memory + callsite_stream.write( + "%s *%s = new %s DACE_ALIGN(64)[%s];\n" % + (nodedesc.dtype.ctype, name, nodedesc.dtype.ctype, arrsize), + sdfg, state_id, node) + self._dispatcher.defined_vars.add(name, DefinedType.Pointer) + if node.setzero: + callsite_stream.write('memset(%s, 0, sizeof(%s)*%s);' % + (name, nodedesc.dtype.ctype, arrsize)) + return + elif (nodedesc.storage == types.StorageType.CPU_Stack + or nodedesc.storage == types.StorageType.Register): + if node.setzero: + callsite_stream.write( + "%s %s[%s] DACE_ALIGN(64) = {0};\n" % + (nodedesc.dtype.ctype, name, arrsize), sdfg, state_id, + node) + self._dispatcher.defined_vars.add(name, DefinedType.Pointer) + return + callsite_stream.write( + "%s %s[%s] DACE_ALIGN(64);\n" % + (nodedesc.dtype.ctype, name, arrsize), sdfg, state_id, node) + self._dispatcher.defined_vars.add(name, DefinedType.Pointer) + return + else: + raise NotImplementedError('Unimplemented storage type ' + + str(nodedesc.storage)) + + def initialize_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + if isinstance(dfg, SDFG): + result = StringIO() + for sid, state in enumerate(dfg.nodes()): + if node in state.nodes(): + self.initialize_array(sdfg, state, sid, node, + function_stream, callsite_stream) + break + return + + parent_node = dfg.scope_dict()[node] + nodedesc = node.desc(sdfg) + name = node.data + + # Traverse the DFG, looking for WCR with an identity element + def traverse(u, uconn, v, vconn, d): + if d.wcr: + if d.data == name: + if d.wcr_identity is 
not None: + return d.wcr_identity + return None + + identity = None + if parent_node is not None: + for u, uconn, v, vconn, d, s in nxutil.traverse_sdfg_scope( + dfg, parent_node): + identity = traverse(u, uconn, v, vconn, d) + if identity is not None: break + else: + for u, uconn, v, vconn, d in dfg.edges(): + identity = traverse(u, uconn, v, vconn, d) + if identity is not None: break + + if identity is None: + return + + # If we should generate an initialization expression + if isinstance(nodedesc, data.Scalar): + callsite_stream.write('%s = %s;\n' % (name, sym2cpp(identity)), + sdfg, state_id, node) + return + + params = [name, sym2cpp(identity)] + shape = [sym2cpp(s) for s in nodedesc.shape] + params.append(' * '.join(shape)) + + # Faster + if identity == 0: + params[-1] += ' * sizeof(%s[0])' % name + callsite_stream.write('memset(%s);\n' % (', '.join(params)), sdfg, + state_id, node) + return + + callsite_stream.write('dace::InitArray(%s);\n' % (', '.join(params)), + sdfg, state_id, node) + + def deallocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + nodedesc = node.desc(sdfg) + if isinstance(nodedesc, data.Scalar): + return + elif isinstance(nodedesc, data.Stream): + return + elif nodedesc.storage == types.StorageType.CPU_Heap: + callsite_stream.write("delete[] %s;\n" % node.data, sdfg, state_id, + node) + else: + return + + def copy_memory(self, sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, callsite_stream): + if isinstance(src_node, nodes.Tasklet): + src_storage = types.StorageType.Register + try: + src_parent = dfg.scope_dict()[src_node] + except KeyError: + src_parent = None + dst_schedule = (None + if src_parent is None else src_parent.map.schedule) + else: + src_storage = src_node.desc(sdfg).storage + + if isinstance(dst_node, nodes.Tasklet): + dst_storage = types.StorageType.Register + else: + dst_storage = dst_node.desc(sdfg).storage + + try: + dst_parent = dfg.scope_dict()[dst_node] + except KeyError: + dst_parent = None + dst_schedule = None if dst_parent is None else dst_parent.map.schedule + + state_dfg = sdfg.nodes()[state_id] + + # Emit actual copy + self._emit_copy(sdfg, state_id, src_node, src_storage, dst_node, + dst_storage, dst_schedule, edge, state_dfg, + callsite_stream) + + def _emit_copy(self, sdfg, state_id, src_node, src_storage, dst_node, + dst_storage, dst_schedule, edge, dfg, stream): + u, uconn, v, vconn, memlet = edge + + ############################################################# + # Instrumentation: Pre-copy + + # For perfcounters, we have to make sure that: + # 1) No other measurements are done for the containing scope (no map operation containing this copy is instrumented) + src_instrumented = PerfUtils.has_surrounding_perfcounters( + src_node, dfg) + dst_instrumented = PerfUtils.has_surrounding_perfcounters( + dst_node, dfg) + + # From cuda.py + cpu_storage_types = [ + types.StorageType.CPU_Heap, types.StorageType.CPU_Stack, + types.StorageType.CPU_Pinned, types.StorageType.Register + ] + + perf_cpu_only = (src_storage in cpu_storage_types) and ( + dst_storage in cpu_storage_types) + + perf_should_instrument = PerfSettings.perf_enable_instrumentation_for( + sdfg) and (not src_instrumented) and ( + not dst_instrumented) and perf_cpu_only + + ############################################################# + + # Determine memlet directionality + if (isinstance(src_node, nodes.AccessNode) + and memlet.data == src_node.data): + write = True + elif (isinstance(dst_node, nodes.AccessNode) + and 
memlet.data == dst_node.data): + write = False + elif isinstance(src_node, nodes.CodeNode) and isinstance( + dst_node, nodes.CodeNode): + # Code->Code copy (not read nor write) + raise RuntimeError( + 'Copying between code nodes is only supported as' + ' part of the participating nodes') + else: + raise LookupError('Memlet does not point to any of the nodes') + + if isinstance(dst_node, nodes.Tasklet): + # Copy into tasklet + stream.write( + ' ' + self.memlet_definition(sdfg, memlet, False, vconn), + sdfg, state_id, [src_node, dst_node]) + return + elif isinstance(src_node, nodes.Tasklet): + # Copy out of tasklet + stream.write( + ' ' + self.memlet_definition(sdfg, memlet, True, uconn), + sdfg, state_id, [src_node, dst_node]) + return + else: # Copy array-to-array + src_nodedesc = src_node.desc(sdfg) + dst_nodedesc = dst_node.desc(sdfg) + + if write: + vconn = dst_node.data + ctype = 'dace::vec<%s, %d>' % (dst_nodedesc.dtype.ctype, + memlet.veclen) + + ############################################# + # Corner cases + + # Writing one index + if isinstance(memlet.subset, + subsets.Indices) and memlet.wcr is None: + stream.write( + '%s = %s;' % (vconn, self.memlet_ctor( + sdfg, memlet, False)), sdfg, state_id, + [src_node, dst_node]) + return + # Writing from/to a stream + if (isinstance(sdfg.arrays[memlet.data], data.Stream) or \ + (isinstance(src_node, nodes.AccessNode) and isinstance(src_nodedesc, + data.Stream))): + # Identify whether a stream is writing to an array + if (isinstance(dst_nodedesc, (data.Scalar, data.Array)) + and isinstance(src_nodedesc, data.Stream)): + return # Do nothing (handled by ArrayStreamView) + + # Array -> Stream - push bulk + if (isinstance(src_nodedesc, (data.Scalar, data.Array)) + and isinstance(dst_nodedesc, data.Stream)): + if hasattr(src_nodedesc, 'src'): # ArrayStreamView + stream.write( + '{s}.push({arr});'.format( + s=dst_node.data, arr=src_nodedesc.src), sdfg, + state_id, [src_node, dst_node]) + else: + copysize = ' * '.join( + [sym2cpp(s) for s in memlet.subset.size()]) + stream.write( + '{s}.push({arr}, {size});'.format( + s=dst_node.data, + arr=src_node.data, + size=copysize), sdfg, state_id, + [src_node, dst_node]) + return + else: + # Unknown case + raise NotImplementedError + + ############################################# + + state_dfg = sdfg.nodes()[state_id] + + copy_shape, src_strides, dst_strides, src_expr, dst_expr = ( + self.memlet_copy_to_absolute_strides(sdfg, memlet, src_node, + dst_node)) + + # Which numbers to include in the variable argument part + dynshape, dynsrc, dyndst = 1, 1, 1 + + # Dynamic copy dimensions + if any(symbolic.issymbolic(s, sdfg.constants) for s in copy_shape): + copy_tmpl = 'Dynamic<{type}, {veclen}, {aligned}, {dims}>'.format( + type=ctype, + veclen=1, # Taken care of in "type" + aligned='false', + dims=len(copy_shape)) + else: # Static copy dimensions + copy_tmpl = '<{type}, {veclen}, {aligned}, {dims}>'.format( + type=ctype, + veclen=1, # Taken care of in "type" + aligned='false', + dims=', '.join(sym2cpp(copy_shape))) + dynshape = 0 + + # Constant src/dst dimensions + if not any( + symbolic.issymbolic(s, sdfg.constants) + for s in dst_strides): + # Constant destination + shape_tmpl = 'template ConstDst<%s>' % ', '.join( + sym2cpp(dst_strides)) + dyndst = 0 + elif not any( + symbolic.issymbolic(s, sdfg.constants) + for s in src_strides): + # Constant source + shape_tmpl = 'template ConstSrc<%s>' % ', '.join( + sym2cpp(src_strides)) + dynsrc = 0 + else: + # Both dynamic + shape_tmpl = 'Dynamic' + + # Parameter 
pack handling + stride_tmpl_args = [0] * ( + dynshape + dynsrc + dyndst) * len(copy_shape) + j = 0 + for shape, src, dst in zip(copy_shape, src_strides, dst_strides): + if dynshape > 0: + stride_tmpl_args[j] = shape + j += 1 + if dynsrc > 0: + stride_tmpl_args[j] = src + j += 1 + if dyndst > 0: + stride_tmpl_args[j] = dst + j += 1 + + copy_args = ([src_expr, dst_expr] + ([] if memlet.wcr is None else + [unparse_cr(memlet.wcr)]) + + sym2cpp(stride_tmpl_args)) + + ############################################################# + # Instrumentation: Pre-copy 2 + unique_cpy_id = PerfSettings.get_unique_number() + + if perf_should_instrument: + fac3 = ' * '.join(sym2cpp(copy_shape)) + " / " + '/'.join( + sym2cpp(dst_strides)) + copy_size = "sizeof(%s) * %s * (%s)" % (ctype, memlet.veclen, + fac3) + node_id = PerfUtils.unified_id(dfg.node_id(dst_node), state_id) + # Mark a section start (this is not really a section in itself (it would be a section with 1 entry)) + stream.write( + "__perf_store.markSectionStart(%d, (long long)%s, PAPI_thread_id());\n" + % (node_id, copy_size), sdfg, state_id, + [src_node, dst_node]) + stream.write(( + "dace_perf::{pcs} __perf_cpy_{nodeid}_{unique_id};\n" + + "auto& __vs_cpy_{nodeid}_{unique_id} = __perf_store.getNewValueSet(__perf_cpy_{nodeid}_{unique_id}, {nodeid}, PAPI_thread_id(), {size}, dace_perf::ValueSetType::Copy);\n" + + "__perf_cpy_{nodeid}_{unique_id}.enterCritical();\n" + ).format( + pcs=PerfUtils.perf_counter_string(dst_node), + nodeid=node_id, + unique_id=unique_cpy_id, + size=copy_size), sdfg, state_id, [src_node, dst_node]) + ############################################################# + + nc = True + if memlet.wcr is not None: + nc = not is_write_conflicted(dfg, edge) + if nc: + stream.write( + """ + dace::CopyND{copy_tmpl}::{shape_tmpl}::{copy_func}( + {copy_args});""".format( + copy_tmpl=copy_tmpl, + shape_tmpl=shape_tmpl, + copy_func='Copy' + if memlet.wcr is None else 'Accumulate', + copy_args=', '.join(copy_args)), sdfg, state_id, + [src_node, dst_node]) + else: # Conflicted WCR + if dynshape == 1: + raise NotImplementedError( + 'Accumulation of dynamically-shaped ' + 'arrays not yet implemented') + elif copy_shape == [ + 1 + ]: # Special case: accumulating one element + dst_expr = self.memlet_view_ctor(sdfg, memlet, True) + stream.write( + write_and_resolve_expr(memlet, nc, dst_expr, + '*(' + src_expr + ')'), sdfg, + state_id, [src_node, dst_node]) + else: + raise NotImplementedError('Accumulation of arrays ' + 'with WCR not yet implemented') + + ############################################################# + # Instrumentation: Post-copy + if perf_should_instrument: + stream.write(("__perf_cpy_%d_%d.leaveCritical(__vs_cpy_%d_%d);\n") + % (node_id, unique_cpy_id, node_id, unique_cpy_id), + sdfg, state_id, [src_node, dst_node]) + ############################################################# + + ########################################################################### + # Memlet handling + + def process_out_memlets(self, sdfg, state_id, node, dfg, dispatcher, + result, locals_defined, function_stream): + + scope_dict = sdfg.nodes()[state_id].scope_dict() + + for edge in dfg.out_edges(node): + _, uconn, v, _, memlet = edge + dst_node = dfg.memlet_path(edge)[-1].dst + + # Target is neither a data nor a tasklet node + if (isinstance(node, nodes.AccessNode) + and (not isinstance(dst_node, nodes.AccessNode) + and not isinstance(dst_node, nodes.CodeNode))): + continue + + # Skip array->code (will be handled as a tasklet input) + if isinstance(node, 
nodes.AccessNode) and isinstance( + v, nodes.CodeNode): + continue + + # code->code (e.g., tasklet to tasklet) + if isinstance(v, nodes.CodeNode): + shared_data_name = 's%d_n%d%s_n%d%s' % ( + state_id, dfg.node_id(edge.src), edge.src_conn, + dfg.node_id(edge.dst), edge.dst_conn) + result.write('__%s = %s;' % (shared_data_name, edge.src_conn), + sdfg, state_id, [edge.src, edge.dst]) + continue + + # If the memlet is not pointing to a data node (e.g. tasklet), then + # the tasklet will take care of the copy + if not isinstance(dst_node, nodes.AccessNode): + continue + # If the memlet is pointing into an array in an inner scope, then + # the inner scope (i.e., the output array) must handle it + if (scope_dict[node] != scope_dict[dst_node] + and scope_contains_scope(scope_dict, node, dst_node)): + continue + + # Array to tasklet (path longer than 1, handled at tasklet entry) + if node == dst_node: + continue + + # Tasklet -> array + if isinstance(node, nodes.CodeNode): + if not uconn: + raise SyntaxError( + 'Cannot copy memlet without a local connector: {} to {}' + .format(str(edge.src), str(edge.dst))) + + try: + positive_accesses = bool(memlet.num_accesses >= 0) + except TypeError: + positive_accesses = False + + if memlet.subset.data_dims() == 0 and positive_accesses: + out_local_name = ' __' + uconn + in_local_name = uconn + if not locals_defined: + out_local_name = self.memlet_ctor(sdfg, memlet, True) + in_memlets = [ + d for _, _, _, _, d in dfg.in_edges(node) + ] + assert len(in_memlets) == 1 + in_local_name = self.memlet_ctor( + sdfg, in_memlets[0], False) + + state_dfg = sdfg.nodes()[state_id] + + if memlet.wcr is not None: + nc = not is_write_conflicted(dfg, edge) + result.write( + write_and_resolve_expr(memlet, nc, out_local_name, + in_local_name), sdfg, + state_id, node) + else: + result.write( + '%s.write(%s);\n' % (out_local_name, + in_local_name), sdfg, + state_id, node) + # Dispatch array-to-array outgoing copies here + elif isinstance(node, nodes.AccessNode): + if dst_node != node and not isinstance(dst_node, + nodes.Tasklet): + dispatcher.dispatch_copy(node, dst_node, edge, sdfg, dfg, + state_id, function_stream, result) + + def memlet_view_ctor(self, sdfg, memlet, is_output): + memlet_params = [] + + memlet_name = memlet.data + def_type = self._dispatcher.defined_vars.get(memlet_name) + + if def_type == DefinedType.Pointer: + memlet_expr = memlet_name # Common case + elif (def_type == DefinedType.Scalar + or def_type == DefinedType.ScalarView): + memlet_expr = '&' + memlet_name + elif def_type == DefinedType.ArrayView: + memlet_expr = memlet_name + ".ptr()" + else: + raise TypeError("Unsupported connector type {}".format(def_type)) + + if isinstance(memlet.subset, subsets.Indices): + + # FIXME: _packed_types influences how this offset is + # generated from the FPGA codegen. We should find a nicer solution. + if self._packed_types is True: + offset = cpp_array_expr( + sdfg, memlet, False, packed_veclen=memlet.veclen) + else: + offset = cpp_array_expr(sdfg, memlet, False) + + # Compute address + memlet_params.append(memlet_expr + ' + ' + offset) + dims = 0 + + else: + + if isinstance(memlet.subset, subsets.Range): + + dims = len(memlet.subset.ranges) + + # FIXME: _packed_types influences how this offset is + # generated from the FPGA codegen. We should find a nicer + # solution. 
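# --- Illustrative sketch (editorial annotation, not part of this patch) ---
# The offset expressions built here (cpp_array_expr / cpp_offset_expr)
# linearize a multi-dimensional index using the array strides, i.e.
# sum(index_k * stride_k). A minimal Python rendition with hypothetical
# names, for intuition only:

def flat_offset(indices, strides):
    """Row-major linearization of an N-D index."""
    return sum(i * s for i, s in zip(indices, strides))

# Index (2, 3) in a 4x5 row-major array (strides (5, 1)) is element 13.
assert flat_offset((2, 3), (5, 1)) == 13
# --- End sketch ---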
+ if self._packed_types is True: + offset = cpp_offset_expr( + sdfg.arrays[memlet.data], + memlet.subset, + packed_veclen=memlet.veclen) + else: + offset = cpp_offset_expr(sdfg.arrays[memlet.data], + memlet.subset) + if offset == "0": + memlet_params.append(memlet_expr) + else: + if (def_type not in [ + DefinedType.Pointer, DefinedType.ArrayView + ]): + raise dace.codegen.codegen.CodegenError( + "Cannot offset address of connector {} of type {}". + format(memlet_name, def_type)) + memlet_params.append(memlet_expr + ' + ' + offset) + + # Dimensions to remove from view (due to having one value) + indexdims = [] + + # Figure out dimensions for scalar version + for dim, (rb, re, rs) in enumerate(memlet.subset.ranges): + try: + if (re - rb) == 0: + indexdims.append(dim) + except TypeError: # cannot determine truth value of Relational + pass + + # Remove index (one scalar) dimensions + dims -= len(indexdims) + + if dims > 0: + strides = memlet.subset.absolute_strides( + sdfg.arrays[memlet.data].strides) + # Filter out index dims + strides = [ + s for i, s in enumerate(strides) if i not in indexdims + ] + # FIXME: _packed_types influences how this offset is + # generated from the FPGA codegen. We should find a nicer + # solution. + if self._packed_types and memlet.veclen > 1: + for i in range(len(strides) - 1): + strides[i] /= memlet.veclen + memlet_params.extend(sym2cpp(strides)) + dims = memlet.subset.data_dims() + + else: + raise RuntimeError( + 'Memlet type "%s" not implemented' % memlet.subset) + + if memlet.num_accesses == 1: + num_accesses_str = "1" + else: # symbolic.issymbolic(memlet.num_accesses, sdfg.constants): + num_accesses_str = 'dace::NA_RUNTIME' + + return 'dace::ArrayView%s<%s, %d, %s, %s> (%s)' % ( + "Out" + if is_output else "In", sdfg.arrays[memlet.data].dtype.ctype, dims, + sym2cpp(memlet.veclen), num_accesses_str, ', '.join(memlet_params)) + + def memlet_definition(self, sdfg, memlet, output, local_name): + result = ('auto __%s = ' % local_name + self.memlet_ctor( + sdfg, memlet, output) + ';\n') + + # Allocate variable type + memlet_type = 'dace::vec<%s, %s>' % ( + sdfg.arrays[memlet.data].dtype.ctype, sym2cpp(memlet.veclen)) + + var_type = self._dispatcher.defined_vars.get(memlet.data) + + # ** Concerning aligned vs. non-aligned values: + # We prefer aligned values, so in every case where we are assigning to + # a local _value_, we explicitly assign to an aligned type + # (memlet_type). In all other cases, where we need either a pointer or + # a reference, typically due to variable number of accesses, we have to + # use the underlying type of the ArrayView, be it aligned or unaligned, + # to avoid runtime crashes. We use auto for this, so the ArrayView can + # return whatever it supports. 
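# --- Editorial annotation (not part of this patch): summary of the cases
# handled below. memlet_definition chooses a declaration per (defined type,
# num_accesses, direction); paraphrasing the branches that follow:
#   Scalar,  1 access,   input  -> "T x = __x.val<N>();"  (pre-read value)
#   Scalar,  1 access,   output -> "T x;"                 (written back later)
#   Scalar, -1 accesses, output -> "auto &x = __x.ref<N>();"
#   Scalar, -1 accesses, input  -> "auto const &x = __x.ref<N>();"
#   Pointer, multiple accesses  -> "auto *x = __x.ptr<N>();"
#   Stream,  1 access,   input  -> "auto x = __x.pop();"
# --- End annotation ---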
+ + if var_type == DefinedType.Scalar: + if memlet.num_accesses == 1: + if not output: + # We can pre-read the value + result += "{} {} = __{}.val<{}>();".format( + memlet_type, local_name, local_name, memlet.veclen) + else: + # The value will be written during the tasklet, and will be + # automatically written out after + result += "{} {};".format(memlet_type, local_name) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Scalar) + elif memlet.num_accesses == -1: + if output: + # Variable number of writes: get reference to the target of + # the view to reflect writes at the data + result += "auto &{} = __{}.ref<{}>();".format( + local_name, local_name, memlet.veclen) + else: + # Variable number of reads: get a const reference that can + # be read if necessary + result += "auto const &{} = __{}.ref<{}>();".format( + local_name, local_name, memlet.veclen) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Scalar) + else: + raise dace.codegen.codegen.CodegenError( + "Unsupported number of accesses {} for scalar {}".format( + memlet.num_accesses, local_name)) + elif var_type == DefinedType.Pointer: + if memlet.num_accesses == 1: + if output: + result += "{} {};".format(memlet_type, local_name) + else: + result += "{} {} = __{}.val<{}>();".format( + memlet_type, local_name, local_name, memlet.veclen) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Scalar) + else: + if memlet.subset.data_dims() == 0: + # Forward ArrayView + result += "auto &{} = __{}.ref<{}>();".format( + local_name, local_name, memlet.veclen) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Scalar) + else: + result += "auto *{} = __{}.ptr<{}>();".format( + local_name, local_name, memlet.veclen) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Pointer) + elif (var_type == DefinedType.Stream + or var_type == DefinedType.StreamArray): + if memlet.num_accesses == 1: + if output: + result += "{} {};".format(memlet_type, local_name) + else: + result += "auto {} = __{}.pop();".format( + local_name, local_name) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Scalar) + else: + # Just forward actions to the underlying object + result += "auto &{} = __{};".format(local_name, local_name) + self._dispatcher.defined_vars.add(local_name, + DefinedType.Stream) + else: + raise TypeError("Unknown variable type: {}".format(var_type)) + + return result + + def memlet_stream_ctor(self, sdfg, memlet): + stream = sdfg.arrays[memlet.data] + dtype = "dace::vec<{}, {}>".format(stream.dtype.ctype, + symbolic.symstr(memlet.veclen)) + return "dace::make_streamview({})".format(memlet.data + ( + "[{}]".format(cpp_offset_expr(stream, memlet.subset)) + if isinstance(stream, dace.data.Stream) + and stream.is_stream_array() else "")) + + def memlet_ctor(self, sdfg, memlet, is_output): + + def_type = self._dispatcher.defined_vars.get(memlet.data) + + if (def_type == DefinedType.Stream + or def_type == DefinedType.StreamArray): + return self.memlet_stream_ctor(sdfg, memlet) + + elif (def_type == DefinedType.Pointer or def_type == DefinedType.Scalar + or def_type == DefinedType.ScalarView + or def_type == DefinedType.ArrayView): + return self.memlet_view_ctor(sdfg, memlet, is_output) + + else: + raise NotImplementedError( + "Connector type {} not yet implemented".format(def_type)) + + def copy_expr(self, + sdfg, + dataname, + memlet, + offset=None, + relative_offset=True, + packed_types=False): + datadesc = sdfg.arrays[dataname] + if relative_offset: + s = memlet.subset + o = offset + 
+        else:
+            if offset is None:
+                s = None
+            elif not isinstance(offset, subsets.Subset):
+                s = subsets.Indices(offset)
+            else:
+                s = offset
+            o = None
+        if s is not None:
+            offset_cppstr = cpp_offset_expr(
+                datadesc, s, o, memlet.veclen if packed_types else 1)
+        else:
+            offset_cppstr = '0'
+        dt = ''
+
+        if memlet.veclen != 1 and not packed_types:
+            offset_cppstr = '(%s) / %s' % (offset_cppstr, sym2cpp(
+                memlet.veclen))
+            dt = '(dace::vec<%s, %s> *)' % (datadesc.dtype.ctype,
+                                            sym2cpp(memlet.veclen))
+
+        expr = dataname
+
+        def_type = self._dispatcher.defined_vars.get(dataname)
+
+        add_offset = (offset_cppstr != "0")
+
+        if def_type == DefinedType.Pointer:
+            return "{}{}{}".format(
+                dt, expr, " + {}".format(offset_cppstr) if add_offset else "")
+
+        elif def_type == DefinedType.ArrayView:
+            return "{}{}.ptr(){}".format(
+                dt, expr, " + {}".format(offset_cppstr) if add_offset else "")
+
+        elif def_type == DefinedType.StreamArray:
+            return "{}[{}]".format(expr, offset_cppstr)
+
+        elif (def_type == DefinedType.Scalar
+              or def_type == DefinedType.ScalarView
+              or def_type == DefinedType.Stream):
+
+            if add_offset:
+                raise TypeError(
+                    "Tried to offset address of scalar {}: {}".format(
+                        dataname, offset_cppstr))
+
+            if (def_type == DefinedType.Scalar
+                    or def_type == DefinedType.ScalarView):
+                return "{}&{}".format(dt, expr)
+            else:
+                return dataname
+
+        else:
+            raise NotImplementedError(
+                "copy_expr not implemented "
+                "for connector type: {}".format(def_type))
+
+    def memlet_copy_to_absolute_strides(self,
+                                        sdfg,
+                                        memlet,
+                                        src_node,
+                                        dst_node,
+                                        packed_types=False):
+        # The packed_types ("ignore vectorization") flag is a hack to
+        # accommodate FPGA behavior, where the pointer type is changed to a
+        # vector type, and addresses thus shouldn't take vectorization into
+        # account.
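# --- Illustrative sketch (editorial annotation, not part of this patch) ---
# The routine below computes absolute strides and then attempts, via
# ndcopy_to_strided_copy, to collapse the copy into fewer dimensions. A
# hedged standalone example of the simplest such collapse (hypothetical
# 2-D-only helper): a copy that is contiguous on both sides can be issued
# as a single 1-D copy.

def collapse_contiguous_2d(copy_shape, src_strides, dst_strides):
    rows, cols = copy_shape
    if list(src_strides) == [cols, 1] and list(dst_strides) == [cols, 1]:
        return [rows * cols], [1], [1]  # one linear copy
    return copy_shape, src_strides, dst_strides  # keep as 2-D strided copy

assert collapse_contiguous_2d([4, 5], [5, 1], [5, 1]) == ([20], [1], [1])
# --- End sketch ---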
+ copy_shape = memlet.subset.size() + copy_shape = [symbolic.overapproximate(s) for s in copy_shape] + src_nodedesc = src_node.desc(sdfg) + dst_nodedesc = dst_node.desc(sdfg) + + if memlet.data == src_node.data: + src_expr = self.copy_expr( + sdfg, src_node.data, memlet, packed_types=packed_types) + dst_expr = self.copy_expr( + sdfg, + dst_node.data, + memlet, + None, + False, + packed_types=packed_types) + if memlet.other_subset is not None: + dst_expr = self.copy_expr( + sdfg, + dst_node.data, + memlet, + memlet.other_subset, + False, + packed_types=packed_types) + dst_subset = memlet.other_subset + else: + dst_subset = subsets.Range.from_array(dst_nodedesc) + src_subset = memlet.subset + + else: + src_expr = self.copy_expr( + sdfg, + src_node.data, + memlet, + None, + False, + packed_types=packed_types) + dst_expr = self.copy_expr( + sdfg, dst_node.data, memlet, packed_types=packed_types) + if memlet.other_subset is not None: + src_expr = self.copy_expr( + sdfg, + src_node.data, + memlet, + memlet.other_subset, + False, + packed_types=packed_types) + src_subset = memlet.other_subset + else: + src_subset = subsets.Range.from_array(src_nodedesc) + dst_subset = memlet.subset + + src_strides = src_subset.absolute_strides(src_nodedesc.strides) + dst_strides = dst_subset.absolute_strides(dst_nodedesc.strides) + + # Try to turn into degenerate/strided ND copies + result = ndcopy_to_strided_copy(copy_shape, src_nodedesc.strides, + src_strides, dst_nodedesc.strides, + dst_strides, memlet.subset) + if result is not None: + copy_shape, src_strides, dst_strides = result + else: + # If other_subset is defined, reduce its dimensionality by + # removing the "empty" dimensions (size = 1) and filter the + # corresponding strides out + src_strides = [ + stride for stride, s in zip(src_strides, src_subset.size()) + if s != 1 + ] + src_strides[len(src_subset):] # Include tiles + if not src_strides: + src_strides = [1] + dst_strides = [ + stride for stride, s in zip(dst_strides, dst_subset.size()) + if s != 1 + ] + dst_strides[len(dst_subset):] # Include tiles + if not dst_strides: + dst_strides = [1] + copy_shape = [s for s in copy_shape if s != 1] + if not copy_shape: + copy_shape = [1] + + # Extend copy shape to the largest among the data dimensions, + # and extend other array with the appropriate strides + if (len(dst_strides) != len(copy_shape) + or len(src_strides) != len(copy_shape)): + if memlet.data == src_node.data: + copy_shape, dst_strides = _reshape_strides( + src_subset, src_strides, dst_strides, copy_shape) + elif memlet.data == dst_node.data: + copy_shape, src_strides = _reshape_strides( + dst_subset, dst_strides, src_strides, copy_shape) + + if memlet.veclen != 1: + int_floor = sp.Function('int_floor') + src_strides[:-1] = [ + int_floor(s, memlet.veclen) for s in src_strides[:-1] + ] + dst_strides[:-1] = [ + int_floor(s, memlet.veclen) for s in dst_strides[:-1] + ] + if not packed_types: + copy_shape[-1] = int_floor(copy_shape[-1], memlet.veclen) + + return copy_shape, src_strides, dst_strides, src_expr, dst_expr + + ######################################################################### + # Dynamically-called node dispatchers + + def _generate_Tasklet(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + callsite_stream.write('{\n', sdfg, state_id, node) + + # Add code to init and exit functions + self._frame._initcode.write(node.code_init, sdfg) + self._frame._exitcode.write(node.code_exit, sdfg) + + state_dfg = sdfg.nodes()[state_id] + + 
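# --- Editorial annotation (not part of this patch) ---
# In the edge loops below, code->code memlets are lowered to shared local
# variables named 's<state_id>_n<src_id><src_conn>_n<dst_id><dst_conn>'
# (prefixed with '__' at the definition site), so a value produced by one
# tasklet can be consumed by the next without an intervening access node.
# --- End annotation ---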
self._dispatcher.defined_vars.enter_scope(node) + + arrays = set() + for edge in state_dfg.in_edges(node): + u = edge.src + memlet = edge.data + + if edge.dst_conn: # Not (None or "") + if edge.dst_conn in arrays: # Disallow duplicates + raise SyntaxError('Duplicates found in memlets') + # Special case: code->code + if isinstance(edge.src, nodes.CodeNode): + shared_data_name = 's%d_n%d%s_n%d%s' % ( + state_id, dfg.node_id(edge.src), edge.src_conn, + dfg.node_id(edge.dst), edge.dst_conn) + + # Read variable from shared storage + callsite_stream.write( + 'const dace::vec<%s, %s>& %s = __%s;' % + (sdfg.arrays[memlet.data].dtype.ctype, + sym2cpp(memlet.veclen), edge.dst_conn, + shared_data_name), sdfg, state_id, + [edge.src, edge.dst]) + self._dispatcher.defined_vars.add(edge.dst_conn, + DefinedType.Scalar) + + else: + src_node = find_input_arraynode(state_dfg, edge) + + self._dispatcher.dispatch_copy( + src_node, node, edge, sdfg, state_dfg, state_id, + function_stream, callsite_stream) + + # Also define variables in the C++ unparser scope + self._locals.define(edge.dst_conn, -1, self._ldepth + 1) + arrays.add(edge.dst_conn) + + callsite_stream.write('\n', sdfg, state_id, node) + + # Use outgoing edges to preallocate output local vars + for edge in state_dfg.out_edges(node): + v = edge.dst + memlet = edge.data + + if edge.src_conn: + if edge.src_conn in arrays: # Disallow duplicates + continue + # Special case: code->code + if isinstance(edge.dst, nodes.CodeNode): + callsite_stream.write( + 'dace::vec<%s, %s> %s;' % + (sdfg.arrays[memlet.data].dtype.ctype, + sym2cpp(memlet.veclen), edge.src_conn), sdfg, + state_id, [edge.src, edge.dst]) + self._dispatcher.defined_vars.add(edge.src_conn, + DefinedType.Scalar) + else: + dst_node = find_output_arraynode(state_dfg, edge) + + self._dispatcher.dispatch_copy( + node, dst_node, edge, sdfg, state_dfg, state_id, + function_stream, callsite_stream) + + # Also define variables in the C++ unparser scope + self._locals.define(edge.src_conn, -1, self._ldepth + 1) + arrays.add(edge.src_conn) + + callsite_stream.write('\n ///////////////////\n', sdfg, state_id, + node) + + unparse_tasklet(sdfg, state_id, dfg, node, function_stream, + callsite_stream, self._locals, self._ldepth) + + callsite_stream.write(' ///////////////////\n\n', sdfg, state_id, + node) + + # Process outgoing memlets + self.process_out_memlets(sdfg, state_id, node, state_dfg, + self._dispatcher, callsite_stream, True, + function_stream) + + ############################################################# + # Instrumentation: Post-tasklet + if PerfSettings.perf_enable_instrumentation( + ) and PerfUtils.has_surrounding_perfcounters(node, dfg): + # Add bytes moved + callsite_stream.write( + "__perf_store.addBytesMoved(%s);" % + PerfUtils.get_tasklet_byte_accesses(node, dfg, sdfg, state_id)) + ############################################################# + + callsite_stream.write('}\n', sdfg, state_id, node) + + self._dispatcher.defined_vars.exit_scope(node) + + def _generate_EmptyTasklet(self, sdfg, dfg, state_id, node, + function_stream, callsite_stream): + self._generate_Tasklet(sdfg, dfg, state_id, node, function_stream, + callsite_stream) + + def _generate_NestedSDFG(self, sdfg, dfg: ScopeSubgraphView, state_id, + node, function_stream: CodeIOStream, + callsite_stream: CodeIOStream): + + self._dispatcher.defined_vars.enter_scope(sdfg) + + # If SDFG parent is not set, set it + node.sdfg._parent = sdfg + state_dfg = sdfg.nodes()[state_id] + + # Take care of nested SDFG I/O + for _, _, _, vconn, 
in_memlet in state_dfg.in_edges(node): + callsite_stream.write( + self.memlet_definition(sdfg, in_memlet, False, vconn), sdfg, + state_id, node) + for _, uconn, _, _, out_memlet in state_dfg.out_edges(node): + callsite_stream.write( + self.memlet_definition(sdfg, out_memlet, True, uconn), sdfg, + state_id, node) + + callsite_stream.write('\n ///////////////////\n', sdfg, state_id, + node) + + sdfg_label = '_%d_%d' % (state_id, dfg.node_id(node)) + # Generate code for internal SDFG + global_code, local_code, used_targets = \ + self._frame.generate_code(node.sdfg, node.schedule, sdfg_label) + + # Write generated code in the proper places (nested SDFG writes + # location info) + function_stream.write(global_code) + callsite_stream.write(local_code) + + callsite_stream.write(' ///////////////////\n\n', sdfg, state_id, + node) + + # Process outgoing memlets with the internal SDFG + self.process_out_memlets(sdfg, state_id, node, state_dfg, + self._dispatcher, callsite_stream, True, + function_stream) + + self._dispatcher.defined_vars.exit_scope(sdfg) + + def _generate_MapEntry(self, sdfg, dfg, state_id, node: nodes.MapEntry, + function_stream, callsite_stream): + map_params = node.map.params + map_name = '__DACEMAP_' + str(state_id) + '_' + str(dfg.node_id(node)) + + unified_id = PerfUtils.unified_id(dfg.node_id(node), state_id) + + ############################################################# + # Instrumentation: Pre-MapEntry + + # Intrusively set the depth + PerfUtils.set_map_depth(node, dfg) + + result = callsite_stream + + map_header = '' + + if PerfSettings.perf_enable_instrumentation(): + idstr = "// (Node %d)\n" % unified_id + map_header += idstr # Used to identify line numbers later + PerfMetaInfoStatic.info.add_node(node, idstr) + + if node.map.schedule == types.ScheduleType.CPU_Multicore: + # We have to find out if we should mark a section start here or later. + children = PerfUtils.all_maps(node, dfg) + + for x in children: + if PerfUtils.map_depth( + x) > PerfSettings.perf_max_scope_depth(): + break # We have our relevant nodes. + if x.map.schedule == types.ScheduleType.CPU_Multicore: + # nested SuperSections are not well-supported + # We have to mark the outermost section, + # which also means that we have to somehow tell the + # lower nodes to not mark the section start. + x.map._can_be_supersection_start = False + + if PerfSettings.perf_enable_instrumentation_for( + sdfg, node + ) and PerfUtils.map_depth( + node + ) <= PerfSettings.perf_max_scope_depth( + ) and node.map._can_be_supersection_start and not dfg.is_parallel( + ): + map_header += "__perf_store.markSuperSectionStart(%d);\n" % unified_id + elif PerfSettings.perf_supersection_emission_debug(): + reasons = [] + if not node.map._can_be_supersection_start: + reasons.append("CANNOT_BE_SS") + if dfg.is_parallel(): + reasons.append("CONTAINER_IS_PARALLEL") + if PerfUtils.map_depth( + node) > PerfSettings.perf_max_scope_depth(): + reasons.append("EXCEED_MAX_DEPTH") + if not PerfSettings.perf_enable_instrumentation_for( + sdfg, node): + reasons.append("MISC") + + map_header += "// SuperSection start not emitted. 
Reasons: " + ",".join( + reasons) + "\n" + + elif PerfSettings.perf_enable_instrumentation_for( + sdfg, node + ) and PerfUtils.map_depth(node) == PerfSettings.perf_max_scope_depth( + ) and node.map._can_be_supersection_start and not dfg.is_parallel(): + # even if the schedule is sequential, we can serialize to + # keep buffer usage low + map_header += "__perf_store.markSuperSectionStart(%d);\n" % unified_id + + if PerfUtils.instrument_entry( + node, dfg) and PerfSettings.perf_enable_instrumentation_for( + sdfg, node): + + size = PerfUtils.accumulate_byte_movements_v2( + node, node, dfg, sdfg, state_id) + size = sp.simplify(size) + + used_symbols = symbolic.symbols_in_sympy_expr(size) + defined_symbols = sdfg.symbols_defined_at(node) + undefined_symbols = [ + x for x in used_symbols if x not in defined_symbols + ] + if len(undefined_symbols) > 0: + # We cannot statically determine the size at this point + print( + "Failed to determine size because of undefined symbols (\"" + + str(undefined_symbols) + "\") in \"" + str(size) + + "\", falling back to 0") + size = 0 + + size = sym2cpp(size) + + map_header += "__perf_store.markSectionStart(%d, (long long)%s, PAPI_thread_id());\n" % ( + unified_id, size) + + ############################################################# + + if node.map.schedule == types.ScheduleType.CPU_Multicore: + map_header += '#pragma omp parallel for' + openmp_parallel_for_defined = True + + # The code below is disabled since we now use pragma omp atomic + # TODO(later): set up register outside loop + #exit_node = dfg.exit_nodes(node)[0] + reduction_stmts = [] + #for outedge in dfg.in_edges(exit_node): + # if (isinstance(outedge.src, nodes.CodeNode) + # and outedge.data.wcr is not None): + # redt = operations.detect_reduction_type(outedge.data.wcr) + # if redt != types.ReductionType.Custom: + # reduction_stmts.append('reduction({typ}:{var})'.format( + # typ=_REDUCTION_TYPE_TO_OPENMP[redt], + # var=outedge.src_conn)) + # reduced_variables.append(outedge) + + map_header += ' %s\n' % ', '.join(reduction_stmts) + + # TODO: Explicit map unroller + if node.map.unroll: + if node.map.schedule == types.ScheduleType.CPU_Multicore: + raise ValueError('An Multicore CPU map cannot be unrolled (' + + node.map.label + ')') + + constsize = all([ + not symbolic.issymbolic(v, sdfg.constants) for r in node.map.range + for v in r + ]) + + # Construct (EXCLUSIVE) map range as a list of comma-delimited C++ + # strings. 
+ maprange_cppstr = [ + '%s, %s, %s' % (sym2cpp(rb), sym2cpp(re + 1), sym2cpp(rs)) + for rb, re, rs in node.map.range + ] + + # Map flattening + if node.map.flatten: + + ############################################################# + # Instrumentation: Post-MapEntry (pre-definitions) + perf_entry_string = ( + 'dace_perf::%s __perf_%d;\n' + + 'auto& __vs_%d = __perf_store.getNewValueSet(__perf_%d, %d, PAPI_thread_id(), %%s);\n' + + '__perf_%d.enterCritical();\n') % ( + PerfUtils.perf_counter_string(node), unified_id, + unified_id, unified_id, unified_id, unified_id) + ############################################################# + + # If the integer set is constant-sized, emit const_int_range + if constsize: + # Generate the loop + result.write( + """ +typedef dace::const_int_range<{range}> {mapname}_rng; +{map_header} +for (int {mapname}_iter = 0; {mapname}_iter < {mapname}_rng::size; ++{mapname}_iter) {{ + """.format( + range=', '.join(maprange_cppstr), + map_header=map_header, + mapname=map_name), sdfg, state_id, node) + + ############################################################# + # Instrumentation: Post-MapEntry (pre-definitions) + # Perfcounters for flattened maps include the calculations + # made to obtain the different axis indices + if PerfUtils.instrument_entry( + node, + dfg) and PerfSettings.perf_enable_instrumentation_for( + sdfg, node): + result.write(perf_entry_string % (map_name + "_iter"), + sdfg, state_id, node) + # remember which map has the counters enabled + node.map._has_papi_counters = True + ############################################################# + + # Generate the variables + for ind, var in enumerate(map_params): + result.write( + ('auto {var} = {mapname}_rng' + + '::index_value({mapname}_iter, ' + '{ind});').format( + ind=ind, var=var, + mapname=map_name), sdfg, state_id, node) + else: # Runtime-size integer range set + # Generate the loop + result.write( + """ +auto {mapname}_rng = dace::make_range({tuplerange}); +{map_header} +for (int {mapname}_iter = 0; {mapname}_iter < {mapname}_rng.size(); ++{mapname}_iter) {{ + """.format( + tuplerange=', '.join([ + 'std::make_tuple(%s)' % cppr + for cppr in maprange_cppstr + ]), + map_header=map_header, + mapname=map_name), sdfg, state_id, node) + + ############################################################# + # Instrumentation: Post-MapEntry (pre-definitions) + # Perfcounters for flattened maps include the calculations + # made to obtain the different axis indices + if PerfUtils.instrument_entry( + node, + dfg) and PerfSettings.perf_enable_instrumentation_for( + sdfg, node): + result.write(perf_entry_string % (map_name + "_iter"), + sdfg, state_id, node) + # remember which map has the counters enabled + node.map._has_papi_counters = True + ############################################################# + + # Generate the variables + for ind, var in enumerate(map_params): + result.write( + ('auto {var} = {mapname}_rng' + + '.index_value({mapname}_iter, ' + '{ind});').format( + ind=ind, var=var, + mapname=map_name), sdfg, state_id, node) + + else: # Nested loops + result.write(map_header, sdfg, state_id, node) + for i, r in enumerate(node.map.range): + #var = '__DACEMAP_%s_%d' % (node.map.label, i) + var = map_params[i] + begin, end, skip = r + + if node.map.unroll: + result.write('#pragma unroll', sdfg, state_id, node) + + result.write( + 'for (auto %s = %s; %s < %s; %s += %s) {\n' % + (var, sym2cpp(begin), var, sym2cpp(end + 1), var, + sym2cpp(skip)), sdfg, state_id, node) + + 
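# --- Illustrative sketch (editorial annotation, not part of this patch) ---
# For flattened maps, "<mapname>_rng.index_value(iter, dim)" recovers each
# axis index from the single flat loop variable. A hedged Python rendition
# (hypothetical helper, assuming the last axis varies fastest):

def index_value(flat_iter, ranges, axis):
    """ranges: (begin, end_exclusive, step) per map dimension."""
    sizes = [(e - b + s - 1) // s for b, e, s in ranges]
    for dim in range(len(ranges) - 1, axis, -1):
        flat_iter //= sizes[dim]
    b, _, s = ranges[axis]
    return b + (flat_iter % sizes[axis]) * s

# In a 2x3 flattened map, flat iteration 4 maps to indices (1, 1).
rngs = [(0, 2, 1), (0, 3, 1)]
assert (index_value(4, rngs, 0), index_value(4, rngs, 1)) == (1, 1)
# --- End sketch ---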
############################################################# + # Instrumentation: Post-MapEntry (pre-definitions) + if PerfUtils.instrument_entry(node, dfg) and ( + (not PerfSettings.perf_debug_profile_innermost and i == 0) + or (PerfSettings.perf_debug_profile_innermost + and i == len(node.map.range) - 1) + ) and PerfSettings.perf_enable_instrumentation_for(sdfg, node): + result.write( + ('dace_perf::%s __perf_%d;\n' + + 'auto& __vs_%d = __perf_store.getNewValueSet(__perf_%d, %d, PAPI_thread_id(), %s);\n' + + '__perf_%d.enterCritical();\n') % + (PerfUtils.perf_counter_string(node), unified_id, + unified_id, unified_id, unified_id, var, unified_id), + sdfg, state_id, node) + # remember which map has the counters enabled + node.map._has_papi_counters = True + ############################################################# + + # Emit internal transient array allocation + to_allocate = dace.sdfg.local_transients(sdfg, dfg, node) + allocated = set() + for child in dfg.scope_dict(node_to_children=True)[node]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in allocated: + continue + allocated.add(child.data) + self._dispatcher.dispatch_allocate(sdfg, dfg, state_id, child, + None, result) + self._dispatcher.dispatch_initialize(sdfg, dfg, state_id, child, + None, result) + + # Generate register definitions for inter-tasklet memlets + scope_dict = dfg.scope_dict() + for edge in dfg.edges(): + # Only interested in edges within current scope + if scope_dict[edge.src] != node or scope_dict[edge.dst] != node: + continue + if (isinstance(edge.src, nodes.CodeNode) + and isinstance(edge.dst, nodes.CodeNode)): + local_name = '__s%d_n%d%s_n%d%s' % ( + state_id, dfg.node_id(edge.src), edge.src_conn, + dfg.node_id(edge.dst), edge.dst_conn) + # Allocate variable type + code = 'dace::vec<%s, %s> %s;' % ( + sdfg.arrays[edge.data.data].dtype.ctype, + sym2cpp(edge.data.veclen), local_name) + result.write(code, sdfg, state_id, [edge.src, edge.dst]) + + def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + result = callsite_stream + + # Obtain start of map + scope_dict = dfg.scope_dict() + map_node = scope_dict[node] + + if map_node is None: + raise ValueError('Exit node ' + str(node.map.label) + + ' is not dominated by a scope entry node') + + ############################################################# + # Instrumentation: Pre-MapExit + unified_id = PerfUtils.unified_id(dfg.node_id(map_node), state_id) + ############################################################# + + # Emit internal transient array deallocation + to_allocate = dace.sdfg.local_transients(sdfg, dfg, map_node) + deallocated = set() + for child in dfg.scope_dict(node_to_children=True)[map_node]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in deallocated: + continue + deallocated.add(child.data) + self._dispatcher.dispatch_deallocate(sdfg, dfg, state_id, child, + None, result) + + # If there are other non-visited map exits, they are responsible for + # closing braces + map_exits = [ + k for k, v in scope_dict.items() + if v == map_node and isinstance(k, nodes.ExitNode) + and k not in self._generated_nodes + ] + if len(map_exits) > 1: + return + + # Map flattening + if map_node.map.flatten: + ############################################################# + # Instrumentation: Pre-MapExit + if PerfSettings.perf_enable_instrumentation( + ) and map_node.map._has_papi_counters: + result.write( + 
'__perf_%d.leaveCritical(__vs_%d);\n' % + (unified_id, unified_id), sdfg, state_id, node) + if PerfSettings.perf_debug_annotate_scopes: + result.write('// %s\n' % str(map_node), sdfg, state_id, node) + ############################################################# + result.write('}', sdfg, state_id, node) + else: + for i, r in enumerate(map_node.map.range): + ############################################################# + # Instrumentation: Pre-MapExit + if PerfSettings.perf_enable_instrumentation( + ) and map_node.map._has_papi_counters and ( + (PerfSettings.perf_debug_profile_innermost and i == 0) or + (not PerfSettings.perf_debug_profile_innermost + and i == len(map_node.map.range) - 1)): + result.write( + '__perf_%d.leaveCritical(__vs_%d);\n' % + (unified_id, unified_id), sdfg, state_id, node) + if PerfSettings.perf_debug_annotate_scopes and i == len( + map_node.map.range) - 1: + result.write('// %s\n' % str(map_node), sdfg, state_id, + node) + ############################################################# + result.write('}', sdfg, state_id, node) + + ############################################################# + # Instrumentation: Post-MapExit + if PerfSettings.perf_enable_vectorization_analysis(): + idstr = "// end (Node %d)\n" % unified_id + result.write(idstr, sdfg, state_id, node) + PerfMetaInfoStatic.info.add_node(node, idstr) + ############################################################# + + def _generate_ConsumeEntry(self, sdfg, dfg, state_id, node: nodes.MapEntry, + function_stream, callsite_stream): + result = callsite_stream + + constsize = all([ + not symbolic.issymbolic(v, sdfg.constants) for r in node.map.range + for v in r + ]) + state_dfg = sdfg.nodes()[state_id] + + input_sedge = next( + e for e in state_dfg.in_edges(node) if e.dst_conn == 'IN_stream') + output_sedge = next( + e for e in state_dfg.out_edges(node) if e.src_conn == 'OUT_stream') + input_stream = state_dfg.memlet_path(input_sedge)[0].src + input_streamdesc = input_stream.desc(sdfg) + + # Take chunks into account + if node.consume.chunksize == 1: + chunk = 'const %s& %s' % (input_streamdesc.dtype.ctype, + node.consume.label + '_element') + self._dispatcher.defined_vars.add(node.consume.label + "_element", + DefinedType.Scalar) + else: + chunk = 'const %s *%s, size_t %s' % ( + input_streamdesc.dtype.ctype, node.consume.label + '_elements', + node.consume.label + '_numelems') + self._dispatcher.defined_vars.add(node.consume.label + "_elements", + DefinedType.Pointer) + self._dispatcher.defined_vars.add(node.consume.label + "_numelems", + DefinedType.Scalar) + + # Take quiescence condition into account + if node.consume.condition is not None: + condition_string = ( + '[&]() { return %s; }, ' % cppunparse.cppunparse( + node.consume.condition, False)) + else: + condition_string = '' + + result.write( + 'dace::Consume<{chunksz}>::template consume{cond}({stream_in}, ' + '{num_pes}, {condition}' + '[&](int {pe_index}, {element_or_chunk}) {{'.format( + chunksz=node.consume.chunksize, + cond='' if node.consume.condition is None else '_cond', + condition=condition_string, + stream_in=input_stream.data, # TODO: stream arrays + element_or_chunk=chunk, + num_pes=sym2cpp(node.consume.num_pes), + pe_index=node.consume.pe_index), + sdfg, + state_id, + node) + + # Since consume is an alias node, we create an actual array for the + # consumed element and modify the outgoing memlet path ("OUT_stream") + # TODO: do this before getting to the codegen + if node.consume.chunksize == 1: + consumed_element = sdfg.add_scalar( + 
node.consume.label + '_element', + input_streamdesc.dtype, + transient=True, + storage=types.StorageType.Register) + ce_node = nodes.AccessNode(node.consume.label + '_element', + types.AccessType.ReadOnly) + else: + consumed_element = sdfg.add_array( + node.consume.label + '_elements', [node.consume.chunksize], + input_streamdesc.dtype, + transient=True, + storage=types.StorageType.Register) + ce_node = nodes.AccessNode(node.consume.label + '_elements', + types.AccessType.ReadOnly) + state_dfg.add_node(ce_node) + out_memlet_path = state_dfg.memlet_path(output_sedge) + state_dfg.remove_edge(out_memlet_path[0]) + state_dfg.add_edge( + out_memlet_path[0].src, out_memlet_path[0].src_conn, ce_node, None, + mmlt.Memlet.from_array(ce_node.data, ce_node.desc(sdfg))) + state_dfg.add_edge( + ce_node, None, out_memlet_path[0].dst, out_memlet_path[0].dst_conn, + mmlt.Memlet.from_array(ce_node.data, ce_node.desc(sdfg))) + for e in out_memlet_path[1:]: + e.data.data = ce_node.data + ## END of SDFG-rewriting code + + # Emit internal transient array allocation + to_allocate = dace.sdfg.local_transients(sdfg, dfg, node) + allocated = set() + for child in dfg.scope_dict(node_to_children=True)[node]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in allocated: + continue + allocated.add(child.data) + self._dispatcher.dispatch_allocate(sdfg, dfg, state_id, child, + None, result) + self._dispatcher.dispatch_initialize(sdfg, dfg, state_id, child, + None, result) + + # Generate register definitions for inter-tasklet memlets + scope_dict = dfg.scope_dict() + for edge in dfg.edges(): + # Only interested in edges within current scope + if scope_dict[edge.src] != node or scope_dict[edge.dst] != node: + continue + if (isinstance(edge.src, nodes.CodeNode) + and isinstance(edge.dst, nodes.CodeNode)): + local_name = '__s%d_n%d%s_n%d%s' % ( + state_id, dfg.node_id(edge.src), edge.src_conn, + dfg.node_id(edge.dst), edge.dst_conn) + # Allocate variable type + code = 'dace::vec<%s, %s> %s;' % ( + sdfg.arrays[edge.data.data].dtype.ctype, + sym2cpp(edge.data.veclen), local_name) + result.write(code, sdfg, state_id, [edge.src, edge.dst]) + + def _generate_ConsumeExit(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + result = callsite_stream + + # Obtain start of map + scope_dict = dfg.scope_dict() + entry_node = scope_dict[node] + + if entry_node is None: + raise ValueError('Exit node ' + str(node.consume.label) + + ' is not dominated by a scope entry node') + + # Emit internal transient array deallocation + to_allocate = dace.sdfg.local_transients(sdfg, dfg, entry_node) + deallocated = set() + for child in dfg.scope_dict(node_to_children=True)[entry_node]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in deallocated: + continue + deallocated.add(child.data) + self._dispatcher.dispatch_deallocate(sdfg, dfg, state_id, child, + None, result) + + result.write('});', sdfg, state_id, node) + + def _generate_Reduce(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + + unified_id = PerfUtils.unified_id(dfg.node_id(node), state_id) + + # Try to autodetect reduction type + redtype = operations.detect_reduction_type(node.wcr) + + loop_header = '' + + perf_should_instrument = PerfSettings.perf_enable_instrumentation( + ) and not PerfUtils.has_surrounding_perfcounters( + node, dfg) and PerfSettings.perf_enable_instrumentation_for( + sdfg, node) + + if node.schedule == 
types.ScheduleType.CPU_Multicore: + if PerfSettings.perf_enable_vectorization_analysis(): + idstr = "// (Node %d)\n" % dfg.node_id(node) + loop_header += idstr + PerfMetaInfoStatic.info.add_node(node, idstr) + loop_header += '#pragma omp parallel for' + + end_braces = 0 + + axes = node.axes + state_dfg = sdfg.nodes()[state_id] + input_memlet = state_dfg.in_edges(node)[0].data + output_edge = state_dfg.out_edges(node)[0] + output_memlet = output_edge.data + + output_type = 'dace::vec<%s, %s>' % ( + sdfg.arrays[output_memlet.data].dtype.ctype, output_memlet.veclen) + + # If axes were not defined, use all input dimensions + input_dims = input_memlet.subset.dims() + output_dims = output_memlet.subset.data_dims() + if axes is None: + axes = tuple(range(input_dims)) + + # Obtain variable names per output and reduction axis + axis_vars = [] + octr = 0 + for d in range(input_dims): + if d in axes: + axis_vars.append('__i%d' % d) + else: + axis_vars.append('__o%d' % octr) + octr += 1 + + ############################################################# + # Instrumentation: Pre-reduce + # For measuring the memory bandwidth, we analyze the amount of data + # moved. + if perf_should_instrument: + perf_expected_data_movement_sympy = 1 + + for axis in range(output_dims): + ao = output_memlet.subset[axis] + perf_expected_data_movement_sympy *= ( + (ao[1] + 1 - ao[0]) / ao[2]) + + for axis in axes: + ai = input_memlet.subset[axis] + perf_expected_data_movement_sympy *= ( + (ai[1] + 1 - ai[0]) / ai[2]) + + if not dfg.is_parallel(): + # Now we put a start marker, but only if we are in a serial state + callsite_stream.write( + '__perf_store.markSuperSectionStart(%d);\n' % (unified_id)) + + callsite_stream.write( + '__perf_store.markSectionStart(%d, (long long)%s, PAPI_thread_id());\n' + % (unified_id, + str(sp.simplify(perf_expected_data_movement_sympy)) + + (" * (sizeof(%s) + sizeof(%s))" % + (sdfg.arrays[output_memlet.data].dtype.ctype, + sdfg.arrays[input_memlet.data].dtype.ctype))), sdfg, + state_id, node) + ############################################################# + + # Write OpenMP loop pragma if there are output dimensions + if output_dims > 0: + callsite_stream.write(loop_header, sdfg, state_id, node) + + # Generate outer loops + output_subset = output_memlet.subset + for axis in range(output_dims): + callsite_stream.write( + 'for (int {var} = {begin}; {var} < {end}; {var} += {skip}) {{'. 
+ format( + var='__o%d' % axis, + begin=output_subset[axis][0], + end=output_subset[axis][1] + 1, + skip=output_subset[axis][2]), sdfg, state_id, node) + + ############################################################# + # Instrumentation: Reduce (part 1) + # This could prevent the compiler from parallelizing/vectorizing + if perf_should_instrument: + if ((end_braces == 0 + and not PerfSettings.perf_debug_profile_innermost) + or (end_braces == output_dims - 1 + and PerfSettings.perf_debug_profile_innermost)): + callsite_stream.write( + 'dace_perf::%s __perf_%d;\n' % + (PerfUtils.perf_counter_string(node), unified_id), + sdfg, state_id, node) + callsite_stream.write( + 'auto& __perf_%d_vs = __perf_store.getNewValueSet(__perf_%d, %d, PAPI_thread_id(), __o%d);\n' + % (unified_id, unified_id, unified_id, axis), sdfg, + state_id, node) + callsite_stream.write( + '__perf_%d.enterCritical();\n' % unified_id, sdfg, + state_id, node) + ############################################################# + end_braces += 1 + + ############################################################# + # Instrumentation: Reduce (part 2) + if end_braces == 0 and perf_should_instrument: + callsite_stream.write( + 'dace_perf::%s __perf_%d;\n' % + (PerfUtils.perf_counter_string(node), unified_id), sdfg, + state_id, node) + callsite_stream.write( + 'auto& __perf_%d_vs = __perf_store.getNewValueSet(__perf_%d, %d, PAPI_thread_id(), 0);\n' + % (unified_id, unified_id, unified_id), sdfg, state_id, node) + callsite_stream.write('__perf_%d.enterCritical();\n' % unified_id, + sdfg, state_id, node) + ############################################################# + + use_tmpout = False + if len(axes) == input_dims: + # Add OpenMP reduction clause if reducing all axes + if (redtype != types.ReductionType.Custom + and node.schedule == types.ScheduleType.CPU_Multicore): + loop_header += ' reduction(%s: __tmpout)' % ( + _REDUCTION_TYPE_TO_OPENMP[redtype]) + + # Output initialization + identity = '' + if node.identity is not None: + identity = ' = %s' % sym2cpp(node.identity) + callsite_stream.write( + '{\n%s __tmpout%s;' % (output_type, identity), sdfg, state_id, + node) + callsite_stream.write(loop_header, sdfg, state_id, node) + end_braces += 1 + use_tmpout = True + + # Generate inner loops (reducing) + input_subset = input_memlet.subset + for axis in axes: + callsite_stream.write( + 'for (int {var} = {begin}; {var} < {end}; {var} += {skip}) {{'. 
+ format( + var='__i%d' % axis, + begin=input_subset[axis][0], + end=input_subset[axis][1] + 1, + skip=input_subset[axis][2]), sdfg, state_id, node) + end_braces += 1 + + # Generate reduction code + credtype = 'dace::ReductionType::' + str( + redtype)[str(redtype).find('.') + 1:] + + # Use index expressions + outvar = ('__tmpout' if use_tmpout else cpp_array_expr( + sdfg, + output_memlet, + offset=['__o%d' % i for i in range(output_dims)], + relative_offset=False)) + invar = cpp_array_expr( + sdfg, input_memlet, offset=axis_vars, relative_offset=False) + + if redtype != types.ReductionType.Custom: + callsite_stream.write( + 'dace::wcr_fixed<%s, %s>::reduce_atomic(&%s, %s);' % + (credtype, output_type, outvar, invar), sdfg, state_id, + node) #cpp_array_expr(), cpp_array_expr() + else: + callsite_stream.write( + 'dace::wcr_custom<%s>::template reduce_atomic(%s, &%s, %s);' % + (output_type, unparse_cr(node.wcr), outvar, invar), sdfg, + state_id, node) #cpp_array_expr(), cpp_array_expr() + + ############################################################# + # Instrumentation: Post-Reduce (pre-braces) + byte_moved_measurement = "__perf_store.addBytesMoved(%s);\n" + + # For reductions, we assume Read-Modify-Write for all operations + # Every reduction statement costs sizeof(input) + sizeof(output). + # This is wrong with some custom reductions or extending operations + # (e.g., i32 * i32 => i64) + # It also is wrong for write-avoiding min/max (min/max that only + # overwrite the reduced variable when it needs to be changed) + + if perf_should_instrument: + callsite_stream.write( + byte_moved_measurement % ("(sizeof(%s) + sizeof(%s))" % + (outvar, invar)), sdfg, state_id, + node) + ############################################################# + + # Generate closing braces + for i in range(end_braces): + # Store back tmpout into the true output + if i == end_braces - 1 and use_tmpout: + callsite_stream.write( + '%s = __tmpout;' % cpp_array_expr(sdfg, output_memlet), + sdfg, state_id, node) + ############################################################# + # Instrumentation: Post-Reduce (in-braces) + if perf_should_instrument and ( + (i == end_braces - 1 + and not PerfSettings.perf_debug_profile_innermost) or + (i == len(axes) + and PerfSettings.perf_debug_profile_innermost)): + callsite_stream.write( + '__perf_%d.leaveCritical(__perf_%d_vs);\n' % + (unified_id, unified_id), sdfg, state_id, node) + ############################################################# + callsite_stream.write('}', sdfg, state_id, node) + + def _generate_AccessNode(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + state_dfg = sdfg.nodes()[state_id] + + if node not in state_dfg.sink_nodes(): + # NOTE: sink nodes are synchronized at the end of a state + presynchronize_streams(sdfg, state_dfg, state_id, node, + callsite_stream) + + sdict = state_dfg.scope_dict() + for edge in state_dfg.in_edges(node): + predecessor, _, _, _, memlet = edge + if memlet.data is None: + continue # If the edge has to be skipped + + # Determines if this path ends here or has a definite source (array) node + memlet_path = state_dfg.memlet_path(edge) + if memlet_path[-1].dst == node: + src_node = memlet_path[0].src + # Only generate code in case this is the innermost scope + # (copies are generated at the inner scope, where both arrays exist) + if (scope_contains_scope(sdict, src_node, node) + and sdict[src_node] != sdict[node]): + self._dispatcher.dispatch_copy( + src_node, node, edge, sdfg, dfg, state_id, + function_stream, 
callsite_stream) + + # Process outgoing memlets (array-to-array write should be emitted + # from the first leading edge out of the array) + self.process_out_memlets(sdfg, state_id, node, state_dfg, + self._dispatcher, callsite_stream, False, + function_stream) + + +######################################################################## +######################################################################## +######################################################################## +######################################################################## +# Helper functions and classes + + +def _reshape_strides(subset, strides, original_strides, copy_shape): + """ Helper function that reshapes a shape to the given strides. """ + # TODO(later): Address original strides in the computation of the + # result strides. + original_copy_shape = subset.size() + dims = len(copy_shape) + + reduced_tile_sizes = [ + ts for ts, s in zip(subset.tile_sizes, original_copy_shape) if s != 1 + ] + + reshaped_copy = copy_shape + [ts for ts in subset.tile_sizes if ts != 1] + reshaped_copy[:len(copy_shape)] = [ + s / ts for s, ts in zip(copy_shape, reduced_tile_sizes) + ] + + new_strides = [0] * len(reshaped_copy) + elements_remaining = functools.reduce(sp.mul.Mul, copy_shape, 1) + tiledim = 0 + for i in range(len(copy_shape)): + new_strides[i] = elements_remaining / reshaped_copy[i] + elements_remaining = new_strides[i] + if reduced_tile_sizes[i] != 1: + new_strides[dims + tiledim] = ( + elements_remaining / reshaped_copy[dims + tiledim]) + elements_remaining = new_strides[dims + tiledim] + tiledim += 1 + + return reshaped_copy, new_strides + + +def ndcopy_to_strided_copy(copy_shape, src_shape, src_strides, dst_shape, + dst_strides, subset): + """ Detects situations where an N-dimensional copy can be degenerated into + a (faster) 1D copy or 2D strided copy. Returns new copy + dimensions and offsets to emulate the requested copy. 
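+        For example (illustrative): copying a full, contiguous 2x3 array
+        to another contiguous 2x3 array degenerates into a single 1D copy
+        of six elements with unit strides, i.e., ([6], [1], [1]).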
+ + @return: a 3-tuple: copy_shape, src_strides, dst_strides + """ + dims = len(copy_shape) + + # Cannot degenerate tiled copies + if any(ts != 1 for ts in subset.tile_sizes): + return None + + # 1D copy of the whole array + if (tuple(copy_shape) == tuple(src_shape) + and tuple(copy_shape) == tuple(dst_shape)): + copy_shape = [functools.reduce(lambda x, y: x * y, copy_shape)] + return copy_shape, [1], [1] + # 1D strided copy + elif sum([0 if c == 1 else 1 for c in copy_shape]) == 1: + # Find the copied dimension: + # In copy shape + copydim = next(i for i, c in enumerate(copy_shape) if c != 1) + + # In source strides + if len(copy_shape) == len(src_shape): + srcdim = copydim + else: + srcdim = next(i for i, c in enumerate(src_shape) if c != 1) + + # In destination strides + if len(copy_shape) == len(dst_shape): + dstdim = copydim + else: + dstdim = next(i for i, c in enumerate(dst_shape) if c != 1) + + # Return new copy + return [copy_shape[copydim]], [src_strides[srcdim]], [ + dst_strides[dstdim] + ] + else: + return None + + +def ndslice_cpp(slice, dims, rowmajor=True): + result = StringIO() + + if len(slice) == 0: # Scalar + return '0' + + for i, d in enumerate(slice): + if isinstance(d, tuple): + raise SyntaxError( + 'CPU backend does not yet support ranges as inputs/outputs') + + # TODO(later): Use access order + + result.write(sym2cpp(d)) + + # If not last + if i < len(slice) - 1: + strdims = [str(dim) for dim in dims[i + 1:]] + result.write( + '*%s + ' % '*'.join(strdims)) # Multiply by leading dimensions + + return result.getvalue() + + +def cpp_offset_expr(d: data.Data, + subset_in: subsets.Subset, + offset=None, + packed_veclen=1): + """ Creates a C++ expression that can be added to a pointer in order + to offset it to the beginning of the given subset and offset. + @param d: The data structure to use for sizes/strides. + @param subset: The subset to offset by. + @param offset: An additional list of offsets or a Subset object + @param packed_veclen: If packed types are targeted, specifies the + vector length that the final offset should be + divided by. + @return: A string in C++ syntax with the correct offset + """ + subset = copy.deepcopy(subset_in) + + # Offset according to parameters + if offset is not None: + if isinstance(offset, subsets.Subset): + subset.offset(offset, False) + else: + subset.offset(subsets.Indices(offset), False) + + # Then, offset according to array + subset.offset(subsets.Indices(d.offset), False) + + # Obtain start range from offsetted subset + slice = [0] * len(d.strides) #subset.min_element() + + index = subset.at(slice, d.strides) + if packed_veclen > 1: + index /= packed_veclen + + return sym2cpp(index) + + +def cpp_array_expr(sdfg, + memlet, + with_brackets=True, + offset=None, + relative_offset=True, + packed_veclen=1): + """ Converts an Indices/Range object to a C++ array access string. """ + s = memlet.subset if relative_offset else subsets.Indices(offset) + o = offset if relative_offset else None + offset_cppstr = cpp_offset_expr(sdfg.arrays[memlet.data], s, o, + packed_veclen) + + if with_brackets: + return '%s[%s]' % (memlet.data, offset_cppstr) + else: + return offset_cppstr + + +def write_and_resolve_expr(memlet, nc, outname, inname, indices=None): + """ Helper function that emits a write_and_resolve call from a memlet. 
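+
+        For example (illustrative), a sum-WCR memlet with no detected
+        conflicts, output connector "out" and input expression "in" yields:
+        out.write_and_resolve_nc<dace::ReductionType::Sum>(in);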
""" + + redtype = operations.detect_reduction_type(memlet.wcr) + + nc = '_nc' if nc else '' + indstr = (', ' + indices) if indices is not None else '' + + reduction_tmpl = '' + custom_reduction = '' + + # Special call for detected reduction types + if redtype != types.ReductionType.Custom: + credtype = ('dace::ReductionType::' + + str(redtype)[str(redtype).find('.') + 1:]) + reduction_tmpl = '<%s>' % credtype + else: + custom_reduction = ', %s' % unparse_cr(memlet.wcr) + + return '{oname}.write_and_resolve{nc}{tmpl}({iname}{wcr}{ind});'.format( + oname=outname, + nc=nc, + tmpl=reduction_tmpl, + iname=inname, + wcr=custom_reduction, + ind=indstr) + + +def is_write_conflicted(dfg, edge, datanode=None): + """ Detects whether a write-conflict-resolving edge can be emitted without + using atomics or critical sections. """ + + if edge.data.wcr_conflict is not None and not edge.data.wcr_conflict: + return False + + if edge is None: + start_node = None + memlet = None + else: + start_node = edge.dst + memlet = edge.data + + # If it's an entire SDFG, it's probably write-conflicted + if isinstance(dfg, SDFG): + if datanode is None: return True + in_edges = find_incoming_edges(datanode, dfg) + if len(in_edges) != 1: return True + if (isinstance(in_edges[0].src, nodes.ExitNode) and + in_edges[0].src.map.schedule == types.ScheduleType.Sequential): + return False + return True + + # Traverse memlet path to determine conflicts. + # If no conflicts will occur, write without atomics + # (e.g., if the array has been defined in a non-parallel schedule context) + # TODO: This is not perfect (need to take indices into consideration) + path = dfg.memlet_path(edge) + for e in path: + if (isinstance(e.dst, nodes.ExitNode) + and e.dst.map.schedule != types.ScheduleType.Sequential): + return True + # Should never happen (no such thing as write-conflicting reads) + if (isinstance(e.src, nodes.EntryNode) + and e.src.map.schedule != types.ScheduleType.Sequential): + return True + + return False + + +def unparse_cr(wcr_ast): + """ Outputs a C++ version of a conflict resolution lambda. 
""" + + if isinstance(wcr_ast, ast.Lambda): + return cppunparse.cppunparse(wcr_ast, expr_semicolon=False) + elif isinstance(wcr_ast, ast.FunctionDef): + # Construct a lambda function out of a function + return '[] (%s) { %s }' % ( + cppunparse.cppunparse(wcr_ast.args, expr_semicolon=False), + cppunparse.cppunparse(wcr_ast.body, expr_semicolon=False)) + elif isinstance(wcr_ast, ast.Module): + return unparse_cr(wcr_ast.body[0].value) + elif isinstance(wcr_ast, str): + return unparse_cr(LambdaProperty.from_string(wcr_ast)) + else: + raise NotImplementedError('INVALID TYPE OF WCR: ' + + type(wcr_ast).__name__) + + +def unparse_tasklet(sdfg, state_id, dfg, node, function_stream, + callsite_stream, locals, ldepth): + + if node.label is None or node.label == "": + return '' + + state_dfg = sdfg.nodes()[state_id] + unified_id = PerfUtils.unified_id(dfg.node_id(node), state_id) + + # Not [], "" or None + if not node.code: + return '' + + # Not [], "" or None + if node.code_global: + if node.language is not types.Language.CPP: + raise ValueError( + "Global code only supported for C++ tasklets: got {}".format( + node.language)) + function_stream.write( + type(node).__properties__["code_global"].to_string( + node.code_global), sdfg, state_id, node) + function_stream.write("\n", sdfg, state_id, node) + + # If raw C++ code, return the code directly + if node.language != types.Language.Python: + # If this code runs on the host and is associated with a CUDA stream, + # set the stream to a local variable. + max_streams = Config.get('compiler', 'cuda', 'max_concurrent_streams') + if (max_streams >= 0 and not is_devicelevel(sdfg, state_dfg, node) + and hasattr(node, '_cuda_stream')): + callsite_stream.write( + 'cudaStream_t __dace_current_stream = dace::cuda::__streams[%d];' + % node._cuda_stream, sdfg, state_id, node) + + if node.language != types.Language.CPP: + raise ValueError( + "Only Python or C++ code supported in CPU codegen, got: {}". 
+ format(node.language)) + callsite_stream.write( + type(node).__properties__["code"].to_string(node.code), sdfg, + state_id, node) + + if (hasattr(node, '_cuda_stream') + and not is_devicelevel(sdfg, state_dfg, node)): + synchronize_streams(sdfg, state_dfg, state_id, node, node, + callsite_stream) + return + + body = node.code + + # Map local names to memlets (for WCR detection) + memlets = {} + for edge in state_dfg.all_edges(node): + u, uconn, v, vconn, memlet = edge + if u == node: + memlet_nc = not is_write_conflicted(dfg, edge) + memlet_wcr = memlet.wcr + + memlets[uconn] = (memlet, memlet_nc, memlet_wcr) + elif v == node: + memlets[vconn] = (memlet, False, None) + + ############################################################# + # Instrumentation: Pre-Tasklet + if PerfSettings.perf_tasklets and PerfSettings.perf_enable_instrumentation( + ): + callsite_stream.write( + 'dace_perf::%s __perf_%s;\n' % + (PerfUtils.perf_counter_string(node), node.label), sdfg, state_id, + node) + callsite_stream.write( + 'auto& __perf_vs_%s = __perf_store.getNewValueSet(__perf_%s, %d, PAPI_thread_id(), 0);\n' + % (node.label, node.label, unified_id), sdfg, state_id, node) + + callsite_stream.write('__perf_%s.enterCritical();\n' % node.label, + sdfg, state_id, node) + + ############################################################# + + callsite_stream.write('// Tasklet code (%s)\n' % node.label, sdfg, + state_id, node) + for stmt in body: + if isinstance(stmt, ast.Expr): + rk = DaCeKeywordRemover(memlets, + sdfg.constants).visit_TopLevelExpr(stmt) + else: + rk = DaCeKeywordRemover(memlets, sdfg.constants).visit(stmt) + + if rk is not None: + # Unparse to C++ and add 'auto' declarations if locals not declared + result = StringIO() + cppunparse.CPPUnparser(rk, ldepth + 1, locals, result) + callsite_stream.write(result.getvalue(), sdfg, state_id, node) + + ############################################################# + # Instrumentation: Post-Tasklet + if PerfSettings.perf_tasklets and PerfSettings.perf_enable_instrumentation( + ): + callsite_stream.write( + '__perf_%s.leaveCritical(__perf_vs_%s);' % + (node.label, node.label), sdfg, state_id, node) + ############################################################# + + +def is_array_stream_view(sdfg, dfg, node): + """ Test whether a stream is directly connected to an array. """ + + # Test all memlet paths from the array. If the path goes directly + # to/from a stream, construct a stream array view + source_paths = [] + sink_paths = [] + for e in dfg.in_edges(node): + src_node = dfg.memlet_path(e)[0].src + if (isinstance(src_node, nodes.AccessNode) + and isinstance(src_node.desc(sdfg), data.Array)): + source_paths.append(src_node) + for e in dfg.out_edges(node): + sink_node = dfg.memlet_path(e)[-1].dst + if (isinstance(sink_node, nodes.AccessNode) + and isinstance(sink_node.desc(sdfg), data.Array)): + sink_paths.append(sink_node) + + # Special case: stream can be represented as a view of an array + if len(source_paths) == 1 or len(sink_paths) == 1: + # TODO: What about a source path? 
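+        # (Note: only the sink path is handled below; this branch assumes
+        #  that sink_paths is non-empty whenever it is taken.)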
+ arrnode = sink_paths[0] + # Only works if the stream itself is not an array of streams + if list(node.desc(sdfg).shape) == [1]: + node.desc(sdfg).sink = arrnode.data # For memlet generation + arrnode.desc( + sdfg).src = node.data # TODO: Move src/sink to node, not array + return True + return False + + +def find_incoming_edges(node, dfg): + # If it's an entire SDFG, look in each state + if isinstance(dfg, SDFG): + result = [] + for state in dfg.nodes(): + result.extend(list(state.in_edges(node))) + return result + else: # If it's one state + return list(dfg.in_edges(node)) + + +def find_outgoing_edges(node, dfg): + # If it's an entire SDFG, look in each state + if isinstance(dfg, SDFG): + result = [] + for state in dfg.nodes(): + result.extend(list(state.out_edges(node))) + return result + else: # If it's one state + return list(dfg.out_edges(node)) + + +def sym2cpp(s): + """ Converts an array of symbolic variables (or one) to C++ strings. """ + if not isinstance(s, list): + return cppunparse.pyexpr2cpp(symbolic.symstr(s)) + return [cppunparse.pyexpr2cpp(symbolic.symstr(d)) for d in s] + + +class DaCeKeywordRemover(ExtNodeTransformer): + """ Removes memlets and other DaCe keywords from a Python AST, and + converts array accesses to C++ methods that can be generated. + + Used for unparsing Python tasklets into C++ that uses the DaCe + runtime. + + @note: Assumes that the DaCe syntax is correct (as verified by the + Python frontend). + """ + + def __init__(self, memlets, constants): + self.memlets = memlets + self.constants = constants + + def visit_TopLevelExpr(self, node): + # This is a DaCe shift, omit it + if isinstance(node.value, ast.BinOp): + if isinstance(node.value.op, ast.LShift) or isinstance( + node.value.op, ast.RShift): + return None + return self.generic_visit(node) + + def visit_AugAssign(self, node): + if not isinstance(node.target, ast.Subscript): + return self.generic_visit(node) + + target = rname(node.target) + if target not in self.memlets: + return self.generic_visit(node) + + raise SyntaxError('Augmented assignments (e.g. 
+=) not allowed on ' + + 'array memlets') + + def visit_Assign(self, node): + target = rname(node.targets[0]) + if target not in self.memlets: + return self.generic_visit(node) + + memlet, nc, wcr = self.memlets[target] + value = self.visit(node.value) + + if not isinstance(node.targets[0], ast.Subscript): + # Dynamic accesses -> every access counts + try: + if memlet is not None and memlet.num_accesses < 0: + if wcr is not None: + newnode = ast.Name( + id=write_and_resolve_expr( + memlet, nc, '__' + target, + cppunparse.cppunparse( + value, expr_semicolon=False))) + else: + newnode = ast.Name(id='__%s.write(%s);' % ( + target, + cppunparse.cppunparse(value, expr_semicolon=False)) + ) + + return ast.copy_location(newnode, node) + except TypeError: # cannot determine truth value of Relational + pass + + return self.generic_visit(node) + + slice = self.visit(node.targets[0].slice) + if not isinstance(slice, ast.Index): + raise NotImplementedError('Range subscripting not implemented') + + if isinstance(slice.value, ast.Tuple): + subscript = unparse(slice)[1:-1] + else: + subscript = unparse(slice) + + if wcr is not None: + newnode = ast.Name( + id=write_and_resolve_expr( + memlet, + nc, + '__' + target, + cppunparse.cppunparse(value, expr_semicolon=False), + indices=subscript)) + else: + newnode = ast.Name(id='__%s.write(%s, %s);' % ( + target, cppunparse.cppunparse(value, expr_semicolon=False), + subscript)) + + return ast.copy_location(newnode, node) + + def visit_Subscript(self, node): + target = rname(node) + if target not in self.memlets and target not in self.constants: + return self.generic_visit(node) + + slice = self.visit(node.slice) + if not isinstance(slice, ast.Index): + raise NotImplementedError('Range subscripting not implemented') + + if isinstance(slice.value, ast.Tuple): + subscript = unparse(slice)[1:-1] + else: + subscript = unparse(slice) + + if target in self.constants: + slice_str = ndslice_cpp( + subscript.split(', '), self.constants[target].shape) + newnode = ast.parse('%s[%s]' % (target, slice_str)).body[0].value + else: + newnode = ast.parse('__%s(%s)' % (target, subscript)).body[0].value + return ast.copy_location(newnode, node) + + def visit_Expr(self, node): + # Check for DaCe function calls + if isinstance(node.value, ast.Call): + # Some calls should not be parsed + if rname(node.value.func) == "define_local": + return None + elif rname(node.value.func) == "define_local_scalar": + return None + elif rname(node.value.func) == "define_stream": + return None + elif rname(node.value.func) == "define_streamarray": + return None + + return self.generic_visit(node) + + def visit_FunctionDef(self, node): + # Do not parse internal functions + return None + + # Replace default modules (e.g., math) with dace::math:: + def visit_Attribute(self, node): + attrname = rname(node) + module_name = attrname[:attrname.rfind('.')] + func_name = attrname[attrname.rfind('.') + 1:] + if module_name in types._ALLOWED_MODULES: + cppmodname = types._ALLOWED_MODULES[module_name] + return ast.copy_location( + ast.Name(id=(cppmodname + func_name), ctx=ast.Load), node) + return self.generic_visit(node) + + +def unique(seq): + seen = set() + return [x for x in seq if not (x in seen or seen.add(x))] + + +# TODO: This should be in the CUDA code generator. 
Add appropriate conditions to node dispatch predicate +def presynchronize_streams(sdfg, dfg, state_id, node, callsite_stream): + state_dfg = sdfg.nodes()[state_id] + if hasattr(node, '_cuda_stream') or is_devicelevel(sdfg, state_dfg, node): + return + for e in state_dfg.in_edges(node): + if hasattr(e.src, '_cuda_stream'): + cudastream = 'dace::cuda::__streams[%d]' % e.src._cuda_stream + callsite_stream.write('cudaStreamSynchronize(%s);' % cudastream, + sdfg, state_id, [e.src, e.dst]) + + +# TODO: This should be in the CUDA code generator. Add appropriate conditions to node dispatch predicate +def synchronize_streams(sdfg, dfg, state_id, node, scope_exit, + callsite_stream): + # Post-kernel stream synchronization (with host or other streams) + max_streams = Config.get('compiler', 'cuda', 'max_concurrent_streams') + if max_streams >= 0: + cudastream = 'dace::cuda::__streams[%d]' % node._cuda_stream + for edge in dfg.out_edges(scope_exit): + # Synchronize end of kernel with output data (multiple kernels + # lead to same data node) + if (isinstance(edge.dst, nodes.AccessNode) + and edge.dst._cuda_stream != node._cuda_stream): + callsite_stream.write( + '''cudaEventRecord(dace::cuda::__events[{ev}], {src_stream}); +cudaStreamWaitEvent(dace::cuda::__streams[{dst_stream}], dace::cuda::__events[{ev}], 0);''' + .format( + ev=edge._cuda_event, + src_stream=cudastream, + dst_stream=edge.dst._cuda_stream), sdfg, state_id, + [edge.src, edge.dst]) + continue + + # We need the streams leading out of the output data + for e in dfg.out_edges(edge.dst): + if isinstance(e.dst, nodes.AccessNode): + continue + # If no stream at destination: synchronize stream with host. + if not hasattr(e.dst, '_cuda_stream'): + pass + # Done at destination + + # If different stream at destination: record event and wait + # for it in target stream. 
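+                # (Illustratively, for event 5 connecting stream 0 to stream
+                #  1, this emits a cudaEventRecord of __events[5] on stream 0
+                #  followed by a cudaStreamWaitEvent on stream 1 for that
+                #  event.)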
+ elif e.dst._cuda_stream != node._cuda_stream: + callsite_stream.write( + '''cudaEventRecord(dace::cuda::__events[{ev}], {src_stream}); + cudaStreamWaitEvent(dace::cuda::__streams[{dst_stream}], dace::cuda::__events[{ev}], 0);''' + .format( + ev=e._cuda_event, + src_stream=cudastream, + dst_stream=e.dst._cuda_stream), sdfg, state_id, + [e.src, e.dst]) + # Otherwise, no synchronization necessary diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py new file mode 100644 index 0000000000..b6abdd8727 --- /dev/null +++ b/dace/codegen/targets/cuda.py @@ -0,0 +1,1794 @@ +from six import StringIO +import ast +import ctypes +import functools +import os +import sympy + +import dace +from dace.frontend import operations +from dace import subsets, symbolic, types +from dace.config import Config +from dace.graph import nodes +from dace.sdfg import ScopeSubgraphView, SDFG, SDFGState, scope_contains_scope, is_devicelevel +from dace.codegen.codeobject import CodeObject +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.targets.target import (TargetCodeGenerator, IllegalCopy, + make_absolute, DefinedType) +from dace.codegen.targets.cpu import (sym2cpp, unparse_cr, cpp_array_expr, + is_array_stream_view, + synchronize_streams) +from dace.codegen.targets.framecode import _set_default_schedule_and_storage_types +from dace.properties import LambdaProperty + +from dace.codegen import cppunparse + +_SPECIAL_RTYPES = { + types.ReductionType.Min_Location: 'ArgMin', + types.ReductionType.Max_Location: 'ArgMax', +} + + +def prod(iterable): + return functools.reduce(sympy.mul.Mul, iterable, 1) + + +def _expr(val): + if isinstance(val, symbolic.SymExpr): + return val.expr + return val + + +class CUDACodeGen(TargetCodeGenerator): + """ GPU (CUDA) code generator. 
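+        Emits host-side wrapper code (stream and event management, kernel
+        launches) and device-side kernel code for GPU-scheduled scopes.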
""" + target_name = 'cuda' + title = 'CUDA' + language = 'cu' + + def __init__(self, frame_codegen, sdfg): + self._frame = frame_codegen + self._dispatcher = frame_codegen.dispatcher + dispatcher = self._dispatcher + + self._in_device_code = False + self._cpu_codegen = None + self._block_dims = None + self._codeobject = CodeObject(sdfg.name + '_' + 'cuda', '', 'cu', + CUDACodeGen, 'CUDA') + self._localcode = CodeIOStream() + self._globalcode = CodeIOStream() + self._initcode = CodeIOStream() + self._exitcode = CodeIOStream() + self._global_sdfg = sdfg + self._toplevel_schedule = None + + # Keep track of current "scope entry/exit" code streams for extra + # code generation + self.scope_entry_stream = self._initcode + self.scope_exit_stream = self._exitcode + + # Annotate CUDA streams and events + self._cuda_streams, self._cuda_events = self._compute_cudastreams(sdfg) + + # Register dispatchers + self._cpu_codegen = dispatcher.get_generic_node_dispatcher() + + # Register additional CUDA dispatchers + dispatcher.register_map_dispatcher(types.GPU_SCHEDULES, self) + + dispatcher.register_node_dispatcher( + self, CUDACodeGen.node_dispatch_predicate) + + dispatcher.register_state_dispatcher(self, + self.state_dispatch_predicate) + + gpu_storage = [ + types.StorageType.GPU_Global, types.StorageType.GPU_Shared, + types.StorageType.GPU_Stack + ] + dispatcher.register_array_dispatcher(gpu_storage, self) + dispatcher.register_array_dispatcher(types.StorageType.CPU_Pinned, + self) + + for storage in gpu_storage: + for other_storage in types.StorageType: + dispatcher.register_copy_dispatcher(storage, other_storage, + None, self) + dispatcher.register_copy_dispatcher(other_storage, storage, + None, self) + + # Register illegal copies + cpu_unpinned_storage = [ + types.StorageType.CPU_Heap, types.StorageType.CPU_Stack + ] + gpu_private_storage = [ + types.StorageType.GPU_Shared, types.StorageType.GPU_Stack + ] + illegal_copy = IllegalCopy() + for st in cpu_unpinned_storage: + for gst in gpu_private_storage: + dispatcher.register_copy_dispatcher(st, gst, None, + illegal_copy) + dispatcher.register_copy_dispatcher(gst, st, None, + illegal_copy) + for st in cpu_unpinned_storage: + for sched_type in [ + types.ScheduleType.GPU_Device, + types.ScheduleType.GPU_ThreadBlock + ]: + dispatcher.register_copy_dispatcher( + st, types.StorageType.Register, sched_type, illegal_copy) + dispatcher.register_copy_dispatcher( + types.StorageType.Register, st, sched_type, illegal_copy) + # End of illegal copies + # End of dispatcher registration + ###################################### + + # Generate final code + def get_generated_codeobjects(self): + fileheader = CodeIOStream() + self._frame.generate_fileheader(self._global_sdfg, fileheader) + + self._codeobject.code = """ +#include +#include + +{file_header} + +DACE_EXPORTED int __dace_init_cuda({params}); +DACE_EXPORTED void __dace_exit_cuda({params}); + +{other_globalcode} + +namespace dace {{ namespace cuda {{ + cudaStream_t __streams[{nstreams}]; + cudaEvent_t __events[{nevents}]; +}} }} + +int __dace_init_cuda({params}) {{ + int count; + + // Check that we are able to run CUDA code + if (cudaGetDeviceCount(&count) != cudaSuccess) + {{ + printf("ERROR: CUDA drivers are not configured or CUDA-capable device " + "not found\\n"); + return 1; + }} + if (count == 0) + {{ + printf("ERROR: No CUDA-capable devices found\\n"); + return 2; + }} + + // Initialize CUDA before we run the application + float *dev_X; + cudaMalloc((void **) &dev_X, 1); + + // Create CUDA streams and 
events + for(int i = 0; i < {nstreams}; ++i) {{ + cudaStreamCreateWithFlags(&dace::cuda::__streams[i], cudaStreamNonBlocking); + }} + for(int i = 0; i < {nevents}; ++i) {{ + cudaEventCreateWithFlags(&dace::cuda::__events[i], cudaEventDisableTiming); + }} + + {initcode} + + return 0; +}} + +void __dace_exit_cuda({params}) {{ + {exitcode} + + // Destroy CUDA streams and events + for(int i = 0; i < {nstreams}; ++i) {{ + cudaStreamDestroy(dace::cuda::__streams[i]); + }} + for(int i = 0; i < {nevents}; ++i) {{ + cudaEventDestroy(dace::cuda::__events[i]); + }} +}} + +{localcode} +""".format(params=self._global_sdfg.signature(), + initcode=self._initcode.getvalue(), + exitcode=self._exitcode.getvalue(), + other_globalcode=self._globalcode.getvalue(), + localcode=self._localcode.getvalue(), + file_header=fileheader.getvalue(), + nstreams=self._cuda_streams, + nevents=self._cuda_events) + + return [self._codeobject] + + @staticmethod + def node_dispatch_predicate(sdfg, node): + if (getattr(node, 'schedule', False) + and node.schedule in types.GPU_SCHEDULES): + return True + return False + + def state_dispatch_predicate(self, sdfg, state): + if self._toplevel_schedule in types.GPU_SCHEDULES: + return True + for node in state.sink_nodes(): + if hasattr(node, '_cuda_stream'): + return True + else: + for e in state.in_edges(node): + if hasattr(e.src, '_cuda_stream'): + return True + return False + + @property + def has_initializer(self): + return True + + @property + def has_finalizer(self): + return True + + @staticmethod + def cmake_options(): + + host_compiler = make_absolute( + Config.get("compiler", "cpu", "executable")) + compiler = make_absolute(Config.get("compiler", "cuda", "executable")) + flags = Config.get("compiler", "cuda", "args") + flags += Config.get("compiler", "cuda", "additional_args") + + # Get CUDA architectures from configuration + cuda_arch = Config.get('compiler', 'cuda', 'cuda_arch').split(',') + cuda_arch = [ca for ca in cuda_arch if ca is not None and len(ca) > 0] + + flags += ' ' + ' '.join( + '-gencode arch=compute_{arch},code=sm_{arch}'.format(arch=arch) + for arch in cuda_arch) + + options = [ + "-DCUDA_HOST_COMPILER=\"{}\"".format(host_compiler), + "-DCUDA_NVCC_FLAGS=\"{}\"".format(flags), + "-DCUDA_TOOLKIT_ROOT_DIR=\"{}\"".format( + os.path.dirname(os.path.dirname(compiler).replace('\\', '/'))) + ] + + return options + + def allocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + nodedesc = node.desc(sdfg) + if isinstance(nodedesc, dace.data.Stream): + return self.allocate_stream(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + + result = StringIO() + arrsize = ' * '.join([ + cppunparse.pyexpr2cpp(symbolic.symstr(s)) for s in nodedesc.strides + ]) + is_dynamically_sized = any( + symbolic.issymbolic(s, sdfg.constants) for s in nodedesc.strides) + arrsize_malloc = arrsize + ' * sizeof(%s)' % nodedesc.dtype.ctype + dataname = node.data + + # Different types of GPU arrays + if nodedesc.storage == types.StorageType.GPU_Global: + result.write( + '%s *%s = nullptr;\n' % (nodedesc.dtype.ctype, dataname)) + self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer) + + # Strides are left to the user's discretion + result.write('cudaMalloc(&%s, %s);\n' % (dataname, arrsize_malloc)) + if node.setzero: + result.write( + 'cudaMemset(%s, 0, %s);\n' % (dataname, arrsize_malloc)) + + elif nodedesc.storage == types.StorageType.CPU_Pinned: + result.write( + '%s *%s = nullptr;\n' % (nodedesc.dtype.ctype, dataname)) + 
self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer) + + # Strides are left to the user's discretion + result.write( + 'cudaMallocHost(&%s, %s);\n' % (dataname, arrsize_malloc)) + if node.setzero: + result.write( + 'memset(%s, 0, %s);\n' % (dataname, arrsize_malloc)) + elif nodedesc.storage == types.StorageType.GPU_Shared: + if is_dynamically_sized: + raise NotImplementedError('Dynamic shared memory unsupported') + result.write("__shared__ %s %s[%s];\n" % (nodedesc.dtype.ctype, + dataname, arrsize)) + self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer) + if node.setzero: + result.write( + 'dace::ResetShared<{type}, {block_size}, {elements}, ' + '1, false>::Reset({ptr});\n'.format( + type=nodedesc.dtype.ctype, + block_size=', '.join(_topy(self._block_dims)), + ptr=dataname, + elements=arrsize)) + elif nodedesc.storage == types.StorageType.GPU_Stack: + if is_dynamically_sized: + raise ValueError('Dynamic allocation of registers not allowed') + szstr = ' = {0}' if node.setzero else '' + result.write("%s %s[%s]%s;\n" % (nodedesc.dtype.ctype, dataname, + arrsize, szstr)) + self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer) + else: + raise NotImplementedError("CUDA: Unimplemented storage type " + + str(nodedesc.storage)) + + callsite_stream.write(result.getvalue(), sdfg, state_id, node) + + def initialize_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + # No need (for now) + pass + + def allocate_stream(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + nodedesc = node.desc(sdfg) + dataname = node.data + if nodedesc.storage == types.StorageType.GPU_Global: + fmtargs = { + 'name': dataname, + 'type': nodedesc.dtype.ctype, + 'is_pow2': sym2cpp( + sympy.log(nodedesc.buffer_size, 2).is_Integer), + 'location': + '%s_%s_%s' % (sdfg.name, state_id, dfg.node_id(node)), + } + + self._dispatcher.defined_vars.add(dataname, DefinedType.Stream) + + if is_array_stream_view(sdfg, dfg, node): + fmtargs['ptr'] = nodedesc.sink + # Assuming 1D array sink/src + fmtargs['size'] = sym2cpp(sdfg.arrays[nodedesc.sink].shape[0]) + + function_stream.write( + 'DACE_EXPORTED void __dace_alloc_{location}({type} *ptr, uint32_t size, dace::GPUStream<{type}, {is_pow2}>& result);'. + format(**fmtargs), sdfg, state_id, node) + self._globalcode.write( + """ +DACE_EXPORTED void __dace_alloc_{location}({type} *ptr, uint32_t size, dace::GPUStream<{type}, {is_pow2}>& result); +void __dace_alloc_{location}({type} *ptr, uint32_t size, dace::GPUStream<{type}, {is_pow2}>& result) {{ + result = dace::AllocGPUArrayStreamView<{type}, {is_pow2}>(ptr, size); +}}""".format(**fmtargs), sdfg, state_id, node) + callsite_stream.write( + 'dace::GPUStream<{type}, {is_pow2}> {name}; __dace_alloc_{location}({ptr}, {size}, {name});'. + format(**fmtargs), sdfg, state_id, node) + else: + fmtargs['size'] = sym2cpp(nodedesc.buffer_size) + + function_stream.write( + 'DACE_EXPORTED void __dace_alloc_{location}(uint32_t size, dace::GPUStream<{type}, {is_pow2}>& result);'. 
+ format(**fmtargs), sdfg, state_id, node) + self._globalcode.write( + """ +DACE_EXPORTED void __dace_alloc_{location}(uint32_t size, dace::GPUStream<{type}, {is_pow2}>& result); +dace::GPUStream<{type}, {is_pow2}> __dace_alloc_{location}(uint32_t size, dace::GPUStream<{type}, {is_pow2}>& result) {{ + result = dace::AllocGPUStream<{type}, {is_pow2}>({size}); +}}""".format(**fmtargs), sdfg, state_id, node) + callsite_stream.write( + 'dace::GPUStream<{type}, {is_pow2}> {name}; __dace_alloc_{location}({size}, {name});'. + format(**fmtargs), sdfg, state_id, node) + + def deallocate_stream(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + nodedesc = node.desc(sdfg) + dataname = node.data + if nodedesc.storage == types.StorageType.GPU_Global: + if is_array_stream_view(sdfg, dfg, node): + callsite_stream.write( + 'dace::FreeGPUArrayStreamView(%s);' % dataname, sdfg, + state_id, node) + else: + callsite_stream.write('dace::FreeGPUStream(%s);' % dataname, + sdfg, state_id, node) + + def deallocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + nodedesc = node.desc(sdfg) + dataname = node.data + if isinstance(nodedesc, dace.data.Stream): + return self.deallocate_stream(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + + if nodedesc.storage == types.StorageType.GPU_Global: + callsite_stream.write('cudaFree(%s);\n' % dataname, sdfg, state_id, + node) + elif nodedesc.storage == types.StorageType.CPU_Pinned: + callsite_stream.write('cudaFreeHost(%s);\n' % dataname, sdfg, + state_id, node) + elif nodedesc.storage == types.StorageType.GPU_Shared or \ + nodedesc.storage == types.StorageType.GPU_Stack: + pass # Do nothing + else: + raise NotImplementedError + + def _compute_cudastreams(self, + sdfg: SDFG, + default_stream=0, + default_event=0): + """ Annotates an SDFG (and all nested ones) to include a `_cuda_stream` + field. This field is applied to all GPU maps, tasklets, and copies + that can be executed in parallel. + @param sdfg: The sdfg to modify. + @param default_stream: The stream ID to start counting from (used + in recursion to nested SDFGs). + @param default_event: The event ID to start counting from (used + in recursion to nested SDFGs). + @return: 2-tuple of the number of streams, events to create. + """ + concurrent_streams = Config.get('compiler', 'cuda', + 'max_concurrent_streams') + if concurrent_streams < 0: + return 0, 0 + + def increment(streams): + if concurrent_streams > 0: + return (streams + 1) % concurrent_streams + return streams + 1 + + state_streams = [] + state_subsdfg_events = [] + + for state in sdfg.nodes(): + # Start by annotating source nodes + source_nodes = state.source_nodes() + + # Concurrency can only be found in each state + max_streams = default_stream + max_events = default_event + + for i, node in enumerate(source_nodes): + if isinstance(node, nodes.AccessNode): + continue + if isinstance(node, nodes.NestedSDFG): + if node.schedule == types.ScheduleType.GPU_Device: + continue + node._cuda_stream = max_streams + node._cs_childpath = False + max_streams = increment(max_streams) + + # Maintain the same CUDA stream in DFS order, add more when + # possible. 
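+            # (Illustratively: a node inherits the CUDA stream of the edge
+            #  source; the first successor on a path keeps that stream, and
+            #  every additional branch is assigned a fresh stream via the
+            #  _cs_childpath flag.)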
+ for e in state.dfs_edges(source_nodes): + if hasattr(e.dst, '_cuda_stream'): + continue + if hasattr(e.src, '_cuda_stream'): + c = e.src._cuda_stream + if e.src._cs_childpath == True: + c = max_streams + max_streams = increment(max_streams) + e.src._cs_childpath = True + else: + c = max_streams + max_streams = increment(max_streams) + e.dst._cuda_stream = c + if not hasattr(e.dst, '_cs_childpath'): + e.dst._cs_childpath = False + if isinstance(e.dst, nodes.NestedSDFG): + if e.dst.schedule not in types.GPU_SCHEDULES: + max_streams, max_events = self._compute_cudastreams( + e.dst.sdfg, e.dst._cuda_stream, max_events + 1) + + state_streams.append(max_streams if concurrent_streams == 0 else + concurrent_streams) + state_subsdfg_events.append(max_events) + + # Remove CUDA streams from paths of non-gpu copies and CPU tasklets + for node, graph in sdfg.all_nodes_recursive(): + if isinstance(graph, SDFGState): + cur_sdfg = graph.parent + for e in graph.out_edges(node): + path = graph.memlet_path(e) + # If leading from/to a GPU memory node, keep stream + if ((isinstance(path[0].src, nodes.AccessNode) + and path[0].src.desc( + cur_sdfg).storage == types.StorageType.GPU_Global) + or (isinstance(path[-1].dst, nodes.AccessNode) + and path[-1].dst.desc(cur_sdfg).storage == + types.StorageType.GPU_Global)): + break + # If leading from/to a GPU tasklet, keep stream + if ((isinstance(path[0].src, nodes.CodeNode) + and is_devicelevel(cur_sdfg, graph, path[0].src)) or + (isinstance(path[-1].dst, nodes.CodeNode) + and is_devicelevel(cur_sdfg, graph, path[-1].dst))): + break + # If leading from/to a GPU reduction, keep stream + if ((isinstance(path[0].src, nodes.Reduce) and + path[0].src.schedule == types.ScheduleType.GPU_Device) + or + (isinstance(path[-1].dst, nodes.Reduce) and path[-1] + .dst.schedule == types.ScheduleType.GPU_Device)): + break + else: # If we did not break, we do not need a CUDA stream + if hasattr(node, '_cuda_stream'): + delattr(node, '_cuda_stream') + # In any case, remove childpath + if hasattr(node, '_cs_childpath'): + delattr(node, '_cs_childpath') + + # Compute maximal number of events by counting edges (within the same + # state) that point from one stream to another + state_events = [] + for i, state in enumerate(sdfg.nodes()): + events = state_subsdfg_events[i] + + for e in state.edges(): + if hasattr(e.src, '_cuda_stream'): + # If there are two or more CUDA streams involved in this + # edge, or the destination is unrelated to CUDA + if (not hasattr(e.dst, '_cuda_stream') + or e.src._cuda_stream != e.dst._cuda_stream): + for mpe in state.memlet_path(e): + mpe._cuda_event = events + events += 1 + + state_events.append(events) + + # Maximum over all states + max_streams = max(state_streams) + max_events = max(state_events) + + return max_streams, max_events + + def _emit_copy(self, state_id, src_node, src_storage, dst_node, + dst_storage, dst_schedule, edge, sdfg, dfg, + callsite_stream): + u, uconn, v, vconn, memlet = edge + state_dfg = sdfg.nodes()[state_id] + + cpu_storage_types = [ + types.StorageType.CPU_Heap, types.StorageType.CPU_Stack, + types.StorageType.CPU_Pinned + ] + gpu_storage_types = [ + types.StorageType.GPU_Global, types.StorageType.GPU_Shared, + types.StorageType.GPU_Stack + ] + + copy_shape = memlet.subset.bounding_box_size() + copy_shape = [symbolic.overapproximate(s) for s in copy_shape] + # Determine directionality + if (isinstance(src_node, nodes.AccessNode) + and memlet.data == src_node.data): + outgoing_memlet = True + elif (isinstance(dst_node, 
nodes.AccessNode) + and memlet.data == dst_node.data): + outgoing_memlet = False + else: + raise LookupError('Memlet does not point to any of the nodes') + + if (isinstance(src_node, nodes.AccessNode) + and isinstance(dst_node, nodes.AccessNode) + and not self._in_device_code + and (src_storage == types.StorageType.GPU_Global + or dst_storage == types.StorageType.GPU_Global)): + src_location = 'Device' if src_storage == types.StorageType.GPU_Global else 'Host' + dst_location = 'Device' if dst_storage == types.StorageType.GPU_Global else 'Host' + + syncwith = {} # Dictionary of {stream: event} + is_sync = False + max_streams = Config.get('compiler', 'cuda', + 'max_concurrent_streams') + + if hasattr(src_node, '_cuda_stream'): + cudastream = src_node._cuda_stream + if not hasattr(dst_node, '_cuda_stream'): + # Copy after which data is needed by the host + is_sync = True + elif dst_node._cuda_stream != src_node._cuda_stream: + syncwith[dst_node._cuda_stream] = edge._cuda_event + else: + pass # Otherwise, no need to synchronize + elif hasattr(dst_node, '_cuda_stream'): + cudastream = dst_node._cuda_stream + else: + if max_streams >= 0: + print('WARNING: Undefined stream, reverting to default') + if dst_location == 'Host': + is_sync = True + cudastream = 'nullptr' + + # Handle case of impending kernel/tasklet on another stream + if max_streams >= 0: + for e in state_dfg.out_edges(dst_node): + if isinstance(e.dst, nodes.AccessNode): + continue + if not hasattr(e.dst, '_cuda_stream'): + is_sync = True + elif e.dst._cuda_stream != cudastream: + syncwith[e.dst._cuda_stream] = e._cuda_event + + if cudastream != 'nullptr': + cudastream = 'dace::cuda::__streams[%d]' % cudastream + + if memlet.wcr is not None: + raise NotImplementedError('Accumulate %s to %s not implemented' + % (src_location, dst_location)) + ############################# + + # Obtain copy information + copy_shape, src_strides, dst_strides, src_expr, dst_expr = ( + self._cpu_codegen.memlet_copy_to_absolute_strides( + sdfg, memlet, src_node, dst_node)) + + dims = len(copy_shape) + + # Handle unsupported copy types + if dims == 2 and (src_strides[-1] != 1 or dst_strides[-1] != 1): + raise NotImplementedError('2D copy only supported with one ' + 'stride') + + # Currently we only support ND copies when they can be represented + # as a 1D copy or as a 2D strided copy + if dims > 2: + raise NotImplementedError('Copies between CPU and GPU are not' + ' supported for N-dimensions') + + if dims == 1: + copysize = ' * '.join([ + cppunparse.pyexpr2cpp(symbolic.symstr(s)) + for s in copy_shape + ]) + array_length = copysize + copysize += ' * sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype + + callsite_stream.write( + 'cudaMemcpyAsync(%s, %s, %s, cudaMemcpy%sTo%s, %s);\n' % + (dst_expr, src_expr, copysize, src_location, dst_location, + cudastream), sdfg, state_id, [src_node, dst_node]) + node_dtype = dst_node.desc(sdfg).dtype + if issubclass(node_dtype.type, ctypes.Structure): + callsite_stream.write( + 'for (auto __idx = 0; __idx < {arrlen}; ++__idx) ' + '{{'.format(arrlen=str(array_length))) + for field_name, field_type in node_dtype._data.items(): + if isinstance(field_type, types.pointer): + tclass = field_type.type + length = node_dtype._length[field_name] + size = 'sizeof({})*{}[__idx].{}'.format( + types._CTYPES[tclass], str(src_node), length) + callsite_stream.write( + 'cudaMalloc(&{dst}[__idx].{fname}, ' + '{sz});'.format( + dst=str(dst_node), + fname=field_name, + sz=size)) + callsite_stream.write( + 'cudaMemcpyAsync({dst}[__idx].{fname}, ' 
+ '{src}[__idx].{fname}, {sz}, ' + 'cudaMemcpy{sloc}To{dloc}, {stream});'.format( + dst=str(dst_node), + src=str(src_node), + fname=field_name, + sz=size, + sloc=src_location, + dloc=dst_location, + stream=cudastream), sdfg, state_id, + [src_node, dst_node]) + callsite_stream.write('}') + elif dims == 2: + callsite_stream.write( + 'cudaMemcpy2DAsync(%s, %s, %s, %s, %s, %s, cudaMemcpy%sTo%s, %s);\n' + % (dst_expr, _topy(dst_strides[0]) + + ' * sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype, + src_expr, sym2cpp(src_strides[0]) + + ' * sizeof(%s)' % src_node.desc(sdfg).dtype.ctype, + sym2cpp(copy_shape[1]) + + ' * sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype, + sym2cpp(copy_shape[0]), src_location, dst_location, + cudastream), sdfg, state_id, [src_node, dst_node]) + + # Post-copy synchronization + if is_sync: + # Synchronize with host (done at destination) + pass + else: + # Synchronize with other streams as necessary + for streamid, event in syncwith.items(): + syncstream = 'dace::cuda::__streams[%d]' % streamid + callsite_stream.write( + ''' + cudaEventRecord(dace::cuda::__events[{ev}], {src_stream}); + cudaStreamWaitEvent({dst_stream}, dace::cuda::__events[{ev}], 0); + '''.format( + ev=event, + src_stream=cudastream, + dst_stream=syncstream), sdfg, state_id, + [src_node, dst_node]) + + # Copy within the GPU + elif (src_storage in gpu_storage_types + and dst_storage in gpu_storage_types): + + state_dfg = sdfg.nodes()[state_id] + sdict = state_dfg.scope_dict() + if scope_contains_scope(sdict, src_node, dst_node): + inner_schedule = dst_schedule + else: + inner_schedule = sdict[src_node] + if inner_schedule is not None: + inner_schedule = inner_schedule.map.schedule + if inner_schedule is None: # Top-level schedule + inner_schedule = self._toplevel_schedule + + # Collaborative load + if inner_schedule == types.ScheduleType.GPU_Device: + # Obtain copy information + copy_shape, src_strides, dst_strides, src_expr, dst_expr = ( + self._cpu_codegen.memlet_copy_to_absolute_strides( + sdfg, memlet, src_node, dst_node)) + + dims = len(copy_shape) + + funcname = 'dace::%sTo%s%dD' % (_get_storagename(src_storage), + _get_storagename(dst_storage), + dims) + + accum = '' + custom_reduction = [] + if memlet.wcr is not None: + redtype = operations.detect_reduction_type(memlet.wcr) + reduction_tmpl = '' + # Special call for detected reduction types + if redtype != types.ReductionType.Custom: + credtype = ('dace::ReductionType::' + + str(redtype)[str(redtype).find('.') + 1:]) + reduction_tmpl = '<%s>' % credtype + else: + custom_reduction = [unparse_cr(memlet.wcr)] + accum = '::template Accum%s' % reduction_tmpl + + if any( + symbolic.issymbolic(s, sdfg.constants) + for s in copy_shape): + callsite_stream.write(( + ' {func}Dynamic, {bdims}, ' + + '{dststrides}, {is_async}>{accum}({args});').format( + func=funcname, + type=dst_node.desc(sdfg).dtype.ctype, + veclen=memlet.veclen, + bdims=', '.join(_topy(self._block_dims)), + dststrides=', '.join(_topy(dst_strides)), + is_async='false' + if state_dfg.out_degree(dst_node) > 0 else 'true', + accum=accum, + args=', '.join([src_expr] + _topy(src_strides) + + [dst_expr] + custom_reduction + + _topy(copy_shape))), sdfg, state_id, + [src_node, dst_node]) + else: + callsite_stream.write(( + ' {func}, {bdims}, {copysize}, ' + + '{dststrides}, {is_async}>{accum}({args});').format( + func=funcname, + type=dst_node.desc(sdfg).dtype.ctype, + veclen=memlet.veclen, + bdims=', '.join(_topy(self._block_dims)), + copysize=', '.join(_topy(copy_shape)), + dststrides=', 
'.join(_topy(dst_strides)), + is_async='false' + if state_dfg.out_degree(dst_node) > 0 else 'true', + accum=accum, + args=', '.join([src_expr] + _topy(src_strides) + + [dst_expr] + custom_reduction)), + sdfg, state_id, [src_node, dst_node]) + # Per-thread load (same as CPU copies) + else: + self._cpu_codegen.copy_memory(sdfg, dfg, state_id, src_node, + dst_node, edge, None, + callsite_stream) + else: + self._cpu_codegen.copy_memory(sdfg, dfg, state_id, src_node, + dst_node, edge, None, + callsite_stream) + + def copy_memory(self, sdfg, dfg, state_id, src_node, dst_node, memlet, + function_stream, callsite_stream): + if isinstance(src_node, nodes.Tasklet): + src_storage = types.StorageType.Register + src_parent = dfg.scope_dict()[src_node] + dst_schedule = None if src_parent is None else src_parent.map.schedule + else: + src_storage = src_node.desc(sdfg).storage + + if isinstance(dst_node, nodes.Tasklet): + dst_storage = types.StorageType.Register + else: + dst_storage = dst_node.desc(sdfg).storage + + dst_parent = dfg.scope_dict()[dst_node] + dst_schedule = None if dst_parent is None else dst_parent.map.schedule + + # Emit actual copy + self._emit_copy(state_id, src_node, src_storage, dst_node, dst_storage, + dst_schedule, memlet, sdfg, dfg, callsite_stream) + + def generate_state(self, sdfg, state, function_stream, callsite_stream): + # Two modes: device-level state and if this state has active streams + if self._toplevel_schedule in types.GPU_SCHEDULES: + self.generate_devicelevel_state(sdfg, state, function_stream, + callsite_stream) + else: + # Active streams found. Generate state normally and sync with the + # streams in the end + self._frame.generate_state( + sdfg, + state, + function_stream, + callsite_stream, + generate_state_footer=False) + if state.nosync == False: + streams_to_sync = set() + for node in state.sink_nodes(): + if hasattr(node, '_cuda_stream'): + streams_to_sync.add(node._cuda_stream) + else: + # Synchronize sink-node copies at the end of the state + for e in state.in_edges(node): + if hasattr(e.src, '_cuda_stream'): + streams_to_sync.add(e.src._cuda_stream) + for stream in streams_to_sync: + callsite_stream.write( + 'cudaStreamSynchronize(dace::cuda::__streams[%d]);' % + stream, sdfg, sdfg.node_id(state)) + + # After synchronizing streams, generate state footer normally + + # Emit internal transient array deallocation + sid = sdfg.node_id(state) + data_to_allocate = (set(state.top_level_transients()) - set( + sdfg.shared_transients())) + deallocated = set() + for node in state.data_nodes(): + if node.data not in data_to_allocate or node.data in deallocated: + continue + deallocated.add(node.data) + self._frame._dispatcher.dispatch_deallocate( + sdfg, state, sid, node, function_stream, callsite_stream) + + def generate_devicelevel_state(self, sdfg, state, function_stream, + callsite_stream): + + # Special case: if this is a GPU grid state and something is reading + # from a possible result of a collaborative write, sync first + if self._toplevel_schedule == types.ScheduleType.GPU_Device: + state_id = next( + i for i, s in enumerate(sdfg.nodes()) if s == state) + for node in state.nodes(): + if (isinstance(node, nodes.AccessNode) and + node.desc(sdfg).storage == types.StorageType.GPU_Shared + and state.in_degree(node) == 0 + and state.out_degree(node) > 0): + callsite_stream.write('__syncthreads();', sdfg, state_id) + break + + self._frame.generate_state(sdfg, state, function_stream, + callsite_stream) + + # NOTE: This function is ONLY called from the CPU side. 
Therefore, any + # schedule that is out of the ordinary will raise an exception + def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, + callsite_stream): + scope_entry = dfg_scope.source_nodes()[0] + scope_exit = dfg_scope.sink_nodes()[0] + + dfg = sdfg.nodes()[state_id] + + # If in device-level code, call appropriate function + if (self._toplevel_schedule == types.ScheduleType.GPU_Device or + (dfg.scope_dict()[scope_entry] is not None and dfg.scope_dict() + [scope_entry].map.schedule in types.GPU_SCHEDULES)): + self.generate_devicelevel_scope(sdfg, dfg_scope, state_id, + function_stream, callsite_stream) + return + + # If not device-level code, ensure the schedule is correct + if scope_entry.map.schedule != types.ScheduleType.GPU_Device: + raise TypeError('Cannot schedule %s directly from non-GPU code' % + str(scope_entry.map.schedule)) + + # Determine whether to create a global (grid) barrier object + create_grid_barrier = False + for node in dfg_scope.nodes(): + if scope_entry == node: continue + if (isinstance(node, nodes.EntryNode) + and node.map.schedule == types.ScheduleType.GPU_Device): + create_grid_barrier = True + + kernel_name = '%s_%d_%d' % ( + scope_entry.map.label, dfg.node_id(scope_entry), sdfg.node_id(dfg)) + + # Get parameters from input/output memlets to this map + params = set(d.data for node in dfg_scope.source_nodes() for _,_,_,_,d in dfg.in_edges(node)) | \ + set(d.data for node in dfg_scope.sink_nodes() for _,_,_,_,d in dfg.out_edges(node)) + + # Get symbolic parameters (free symbols) for kernel + syms = sdfg.symbols_defined_at(scope_entry) + freesyms = { + k: v + for k, v in syms.items() + if k not in sdfg.constants and k not in scope_entry.map.params + } + symbol_sigs = [ + v.dtype.ctype + ' ' + k for k, v in sorted(freesyms.items()) + ] + symbol_names = [k for k in sorted(freesyms.keys())] + + # Hijack symbol_sigs to create a grid barrier object + if create_grid_barrier: + symbol_sigs.append('cub::GridBarrier __gbar') + + # Comprehend grid/block dimensions from scopes + grid_dims, block_dims, tbmap = self.get_kernel_dimensions(dfg_scope) + + kernel_args = [ + sdfg.arrays[p].signature(False, name=p) for p in sorted(params) + ] + symbol_names + kernel_args_typed = [ + sdfg.arrays[p].signature(name=p) for p in sorted(params) + ] + symbol_sigs + + # Store init/exit code streams + old_entry_stream = self.scope_entry_stream + old_exit_stream = self.scope_exit_stream + self.scope_entry_stream = CodeIOStream() + self.scope_exit_stream = CodeIOStream() + + kernel_stream = CodeIOStream() + self.generate_kernel_scope(sdfg, dfg_scope, state_id, scope_entry.map, + kernel_name, grid_dims, block_dims, tbmap, + kernel_args_typed, self._globalcode, + kernel_stream) + + # Write kernel prototype + node = dfg_scope.source_nodes()[0] + self._localcode.write( + '__global__ void %s(%s) {\n' % + (kernel_name, ', '.join(kernel_args_typed)), sdfg, state_id, node) + + # Write constant expressions in GPU code + self._frame.generate_constants(sdfg, self._localcode) + + self._localcode.write(self.scope_entry_stream.getvalue()) + + # Assuming kernel can write to global scope (function_stream), we + # output the kernel last + self._localcode.write(kernel_stream.getvalue() + '\n') + + self._localcode.write(self.scope_exit_stream.getvalue()) + + # Restore init/exit code streams + self.scope_entry_stream = old_entry_stream + self.scope_exit_stream = old_exit_stream + + # Write callback function definition + self._localcode.write( + """ +DACE_EXPORTED void 
__dace_runkernel_{fname}({fargs}); +void __dace_runkernel_{fname}({fargs}) +{{ +""".format(fname=kernel_name, fargs=', '.join(kernel_args_typed)), sdfg, + state_id, node) + + if create_grid_barrier: + gbar = '__gbar_' + kernel_name + self._localcode.write(' cub::GridBarrierLifetime %s;\n' % gbar, + sdfg, state_id, node) + self._localcode.write( + ' %s.Setup(%s);\n' % (gbar, ' * '.join(_topy(grid_dims))), + sdfg, state_id, node) + symbol_names.append(gbar) + + # Compute dynamic shared memory + dynsmem_size = 0 + # For all access nodes, if array storage == GPU_Shared and size is + # symbolic, add it. If nested SDFG, check all internal arrays + for node in dfg_scope.nodes(): + if isinstance(node, nodes.AccessNode): + arr = sdfg.arrays[node.data] + if arr.storage == types.StorageType.GPU_Shared: + numel = functools.reduce(lambda a, b: a * b, arr.shape) + if symbolic.issymbolic(numel, sdfg.constants): + dynsmem_size += numel + elif isinstance(node, nodes.NestedSDFG): + for arr in node.sdfg.arrays_recursive(): + if (arr is not None + and arr.storage == types.StorageType.GPU_Shared): + numel = functools.reduce(lambda a, b: a * b, arr.shape) + if symbolic.issymbolic(numel, sdfg.constants): + dynsmem_size += numel + + max_streams = Config.get('compiler', 'cuda', 'max_concurrent_streams') + if max_streams >= 0: + cudastream = 'dace::cuda::__streams[%d]' % scope_entry._cuda_stream + else: + cudastream = 'nullptr' + + self._localcode.write( + ''' +void *{kname}_args[] = {{ {kargs} }}; +cudaLaunchKernel((void*){kname}, dim3({gdims}), dim3({bdims}), {kname}_args, {dynsmem}, {stream});''' + .format( + kname=kernel_name, + kargs=', '.join(['(void *)&' + arg for arg in kernel_args]), + gdims=','.join(_topy(grid_dims)), + bdims=','.join(_topy(block_dims)), + dynsmem=_topy(dynsmem_size), + stream=cudastream), sdfg, state_id, node) + + # Close the runkernel function + self._localcode.write('}') + ####################### + # Add invocation to calling code (in another file) + function_stream.write( + 'DACE_EXPORTED void __dace_runkernel_%s(%s);\n' % + (kernel_name, ', '.join(kernel_args_typed)), sdfg, state_id, node) + callsite_stream.write( + '__dace_runkernel_%s(%s);\n' % + (kernel_name, ', '.join(kernel_args)), sdfg, state_id, node) + + synchronize_streams(sdfg, dfg, state_id, node, scope_exit, + callsite_stream) + + def get_kernel_dimensions(self, dfg_scope): + """ Determines a CUDA kernel's grid/block dimensions from map + scopes. + + Ruleset for kernel dimensions: + 1. If only one map (device-level) exists, of an integer set S, + the block size is 32x1x1 and grid size is ceil(|S|/32) in + 1st dimension. + 2. If nested thread-block maps exist (T_1,...,T_n), grid + size is |S| and block size is max(|T_1|,...,|T_n|) with + block specialization. + 3. If block size can be overapproximated, it is (for + dynamically-sized blocks that are bounded by a + predefined size). + + @note: Kernel dimensions are separate from the map + variables, and they should be treated as such. + @note: To make use of the grid/block 3D registers, we use multi- + dimensional kernels up to 3 dimensions, and flatten the + rest into the third dimension. 
+ """ + + kernelmap_entry = dfg_scope.source_nodes()[0] + grid_size = kernelmap_entry.map.range.size(True)[::-1] + block_size = None + + # Linearize (flatten) rest of dimensions to third + if len(grid_size) > 3: + grid_size[2] = functools.reduce(sympy.mul.Mul, grid_size[2:], 1) + del grid_size[3:] + + # Extend to 3 dimensions if necessary + grid_size = grid_size + [1] * (3 - len(grid_size)) + + # Obtain thread-block maps for case (2) + tb_maps = [ + node.map for node, parent in dfg_scope.scope_dict().items() + if parent == kernelmap_entry and isinstance(node, nodes.EntryNode) + and node.schedule == types.ScheduleType.GPU_ThreadBlock + ] + # Append thread-block maps from nested SDFGs + for node in dfg_scope.scope_subgraph(kernelmap_entry).nodes(): + if isinstance(node, nodes.NestedSDFG): + _set_default_schedule_and_storage_types( + node.sdfg, node.schedule) + + tb_maps.extend([ + n.map for state in node.sdfg.nodes() + for n in state.nodes() if isinstance(n, nodes.MapEntry) + and n.schedule == types.ScheduleType.GPU_ThreadBlock + ]) + + # Case (1): no thread-block maps + if len(tb_maps) == 0: + + print('WARNING: Thread-block maps not found in kernel, assuming ' + + 'block size of (%s)' % + Config.get('compiler', 'cuda', 'default_block_size')) + block_size = [ + int(b) for b in Config.get('compiler', 'cuda', + 'default_block_size').split(',') + ] + assert (len(block_size) >= 1 and len(block_size) <= 3) + + int_ceil = sympy.Function('int_ceil') + + # Grid size = ceil(|S|/32) for first dimension, rest = |S| + grid_size = [ + int_ceil(gs, bs) for gs, bs in zip(grid_size, block_size) + ] + + return grid_size, block_size, False + + # Find all thread-block maps to determine overall block size + block_size = [1, 1, 1] + detected_block_sizes = [block_size] + for tbmap in tb_maps: + tbsize = tbmap.range.size()[::-1] + + # Over-approximate block size (e.g. min(N,(i+1)*32)-i*32 --> 32) + # The partial trailing thread-block is emitted as an if-condition + # that returns on some of the participating threads + tbsize = [symbolic.overapproximate(s) for s in tbsize] + + # Linearize (flatten) rest of dimensions to third + if len(tbsize) > 3: + tbsize[2] = functools.reduce(sympy.mul.Mul, tbsize[2:], 1) + del tbsize[3:] + + # Extend to 3 dimensions if necessary + tbsize = tbsize + [1] * (len(block_size) - len(tbsize)) + + block_size = [ + sympy.Max(sz, bbsz) for sz, bbsz in zip(block_size, tbsize) + ] + if block_size != tbsize: + detected_block_sizes.append(tbsize) + + # TODO: If grid/block sizes contain elements only defined within the + # kernel, raise an invalid SDFG exception and recommend + # overapproximation. 
+ + return grid_size, block_size, True + + def generate_kernel_scope( + self, sdfg: SDFG, dfg_scope: ScopeSubgraphView, state_id: int, + kernel_map: nodes.Map, kernel_name: str, grid_dims: list, + block_dims: list, has_tbmap: bool, kernel_params: list, + function_stream: CodeIOStream, kernel_stream: CodeIOStream): + node = dfg_scope.source_nodes()[0] + + if not node.map.flatten: + # Add more opening braces for scope exit to close + for dim in range(len(node.map.range) - 1): + kernel_stream.write('{\n', sdfg, state_id, node) + + # Generate all index arguments for kernel grid + krange = subsets.Range(kernel_map.range[::-1]) + kdims = krange.size() + dsym = [ + symbolic.symbol('__DAPB%d' % i, nonnegative=True, integer=True) + for i in range(len(krange)) + ] + bidx = krange.coord_at(dsym) + + # First three dimensions are evaluated directly + for i in range(min(len(krange), 3)): + varname = kernel_map.params[-i - 1] + + # Delinearize third dimension if necessary + if i == 2 and len(krange) > 3: + block_expr = '(blockIdx.z / (%s))' % _topy( + functools.reduce(sympy.mul.Mul, kdims[3:], 1)) + else: + block_expr = 'blockIdx.%s' % _named_idx(i) + # If we defaulted to 32 threads per block, offset by thread ID + if not has_tbmap: + block_expr = '(%s * %s + threadIdx.%s)' % ( + block_expr, _topy(block_dims[i]), _named_idx(i)) + + expr = _topy(bidx[i]).replace('__DAPB%d' % i, block_expr) + + kernel_stream.write('int %s = %s;' % (varname, expr), sdfg, + state_id, node) + self._dispatcher.defined_vars.add(varname, DefinedType.Scalar) + + # Delinearize beyond the third dimension + if len(krange) > 3: + for i in range(3, len(krange)): + varname = kernel_map.params[-i - 1] + # true dim i = z / ('*'.join(kdims[i+1:])) % kdims[i] + block_expr = '(blockIdx.z / (%s)) %% (%s)' % ( + _topy(functools.reduce(sympy.mul.Mul, kdims[i + 1:], 1)), + _topy(kdims[i]), + ) + + expr = _topy(bidx[i]).replace('__DAPB%d' % i, block_expr) + kernel_stream.write('int %s = %s;' % (varname, expr), sdfg, + state_id, node) + self._dispatcher.defined_vars.add(varname, DefinedType.Scalar) + + # Dispatch internal code + assert self._in_device_code == False + self._in_device_code = True + self._block_dims = block_dims + + # Emit internal array allocation (deallocation handled at MapExit) + scope_entry = dfg_scope.source_nodes()[0] + to_allocate = dace.sdfg.local_transients(sdfg, dfg_scope, scope_entry) + allocated = set() + for child in dfg_scope.scope_dict(node_to_children=True)[node]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in allocated: + continue + allocated.add(child.data) + self._dispatcher.dispatch_allocate(sdfg, dfg_scope, state_id, + child, function_stream, + kernel_stream) + self._dispatcher.dispatch_initialize(sdfg, dfg_scope, state_id, + child, function_stream, + kernel_stream) + + # Generate conditions for this block's execution using min and max + # element, e.g., skipping out-of-bounds threads in trailing block + if has_tbmap == False: + dsym_end = [d + bs - 1 for d, bs in zip(dsym, self._block_dims)] + minels = krange.min_element() + maxels = krange.max_element() + for i, (v, minel, maxel) in enumerate( + zip(kernel_map.params[::-1], minels, maxels)): + condition = '' + + # Optimize conditions if they are always true + if i >= 3 or (dsym[i] >= minel) != True: + condition += '%s >= %s' % (v, _topy(minel)) + if i >= 3 or (dsym_end[i] < maxel) != False: + if len(condition) > 0: + condition += ' && ' + condition += '%s < %s' % (v, _topy(maxel + 1)) + if 
len(condition) > 0: + kernel_stream.write('if (%s) {' % condition, sdfg, + state_id, scope_entry) + else: + kernel_stream.write('{', sdfg, state_id, scope_entry) + + self._dispatcher.dispatch_subgraph( + sdfg, + dfg_scope, + state_id, + function_stream, + kernel_stream, + skip_entry_node=True) + + if has_tbmap == False: + for _ in kernel_map.params: + kernel_stream.write('}\n', sdfg, state_id, node) + + self._block_dims = None + self._in_device_code = False + + def get_next_scope_entries(self, dfg, scope_entry): + parent_scope_entry = dfg.scope_dict()[scope_entry] + # We're in a nested SDFG, use full graph + if parent_scope_entry is None: + parent_scope = dfg + else: + parent_scope = dfg.scope_subgraph(parent_scope_entry) + + # Get all non-sequential scopes from the same level + all_scopes = [ + node for node in parent_scope.topological_sort(scope_entry) + if isinstance(node, nodes.EntryNode) + and node.map.schedule != types.ScheduleType.Sequential + ] + + # TODO: Fix to include *next* scopes, without concurrent scopes + + return all_scopes[all_scopes.index(scope_entry) + 1:] + + def generate_devicelevel_scope(self, sdfg, dfg_scope, state_id, + function_stream, callsite_stream): + # Sanity check + assert self._in_device_code == True + + dfg = sdfg.nodes()[state_id] + sdict = dfg.scope_dict() + scope_entry = dfg_scope.source_nodes()[0] + scope_map = scope_entry.map + next_scopes = self.get_next_scope_entries(dfg, scope_entry) + + if scope_map.schedule == types.ScheduleType.GPU_ThreadBlock_Dynamic: + if len(scope_map.params) > 1: + raise ValueError('Only one-dimensional maps are supported for ' + 'dynamic block map schedule (got %d)' % len( + scope_map.params)) + total_block_size = 1 + for bdim in self._block_dims: + if symbolic.issymbolic(bdim, sdfg.constants): + raise ValueError( + 'Block size has to be constant for block-wide ' + 'dynamic map schedule (got %s)' % str(bdim)) + total_block_size *= bdim + if _expr(scope_map.range[0][2]) != 1: + raise NotImplementedError( + 'Skip not implemented for dynamic thread-block map schedule' + ) + + ##### TODO (later): Generalize + # Find thread-block param map and its name + if self._block_dims[1] != 1 or self._block_dims[2] != 1: + raise NotImplementedError( + 'Dynamic block map schedule only ' + 'implemented for 1D blocks currently') + pscope = sdict[scope_entry] + while pscope is not None and pscope.map.schedule != types.ScheduleType.GPU_ThreadBlock: + pscope = sdict[pscope] + if pscope is None: + raise NotImplementedError('Dynamic block map schedule ' + 'currently requires block map') + bname = pscope.map.params[0] + + callsite_stream.write( + 'dace::DynamicMap<{bsize}>::template ' + 'schedule({begin}, {end}, {tid}, [&](auto {param}, ' + 'auto {tid}) {{'.format( + bsize=total_block_size, + begin=scope_map.range[0][0], + end=scope_map.range[0][1] + 1, + param=scope_map.params[0], + tid=bname), sdfg, state_id, scope_entry) + else: + # If integer sets are used, only emit one opening curly brace + if scope_map.flatten: + callsite_stream.write('{', sdfg, state_id, scope_entry) + else: + for dim in range(len(scope_map.range)): + callsite_stream.write('{', sdfg, state_id, scope_entry) + + # Emit internal array allocation (deallocation handled at MapExit) + to_allocate = dace.sdfg.local_transients(sdfg, dfg_scope, scope_entry) + allocated = set() + for child in dfg_scope.scope_dict(node_to_children=True)[scope_entry]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in allocated: + continue + 
allocated.add(child.data) + self._dispatcher.dispatch_allocate(sdfg, dfg_scope, state_id, + child, function_stream, + callsite_stream) + self._dispatcher.dispatch_initialize(sdfg, dfg_scope, state_id, + child, function_stream, + callsite_stream) + + # Generate all index arguments for block + if scope_map.schedule == types.ScheduleType.GPU_ThreadBlock: + brange = subsets.Range(scope_map.range[::-1]) + kdims = brange.size() + dsym = [ + symbolic.symbol( + '__DAPT%d' % i, nonnegative=True, integer=True) + for i in range(len(brange)) + ] + dsym_end = [d + bs - 1 for d, bs in zip(dsym, self._block_dims)] + tidx = brange.coord_at(dsym) + + # First three dimensions are evaluated directly + for i in range(min(len(brange), 3)): + varname = scope_map.params[-i - 1] + + # Delinearize third dimension if necessary + if i == 2 and len(brange) > 3: + block_expr = '(threadIdx.z / (%s))' % _topy( + functools.reduce(sympy.mul.Mul, kdims[3:], 1)) + else: + block_expr = 'threadIdx.%s' % _named_idx(i) + + expr = _topy(tidx[i]).replace('__DAPT%d' % i, block_expr) + callsite_stream.write('int %s = %s;' % (varname, expr), sdfg, + state_id, scope_entry) + self._dispatcher.defined_vars.add(varname, DefinedType.Scalar) + + # Delinearize beyond the third dimension + if len(brange) > 3: + for i in range(3, len(brange)): + varname = scope_map.params[-i - 1] + # true dim i = z / ('*'.join(kdims[i+1:])) % kdims[i] + block_expr = '(threadIdx.z / (%s)) %% (%s)' % ( + _topy( + functools.reduce(sympy.mul.Mul, kdims[i + 1:], 1)), + _topy(kdims[i]), + ) + + expr = _topy(tidx[i]).replace('__DAPT%d' % i, block_expr) + callsite_stream.write('int %s = %s;' % (varname, expr), + sdfg, state_id, scope_entry) + self._dispatcher.defined_vars.add(varname, + DefinedType.Scalar) + + # Generate conditions for this block's execution using min and max + # element, e.g. 
skipping out-of-bounds threads in trailing block + minels = brange.min_element() + maxels = brange.max_element() + for i, (v, minel, maxel) in enumerate( + zip(scope_map.params[::-1], minels, maxels)): + condition = '' + + # Optimize conditions if they are always true + if i >= 3 or (dsym[i] >= minel) != True: + condition += '%s >= %s' % (v, _topy(minel)) + if i >= 3 or (dsym_end[i] < maxel) != False: + if len(condition) > 0: + condition += ' && ' + condition += '%s < %s' % (v, _topy(maxel + 1)) + if len(condition) > 0: + callsite_stream.write('if (%s) {' % condition, sdfg, + state_id, scope_entry) + else: + callsite_stream.write('{', sdfg, state_id, scope_entry) + ########################################################## + + # Generate contents normally + self._dispatcher.dispatch_subgraph( + sdfg, + dfg_scope, + state_id, + function_stream, + callsite_stream, + skip_entry_node=True) + + # If there are any other threadblock maps down the road, + # synchronize the thread-block / grid + if len(next_scopes) > 0: + # Thread-block synchronization + if scope_entry.map.schedule == types.ScheduleType.GPU_ThreadBlock: + callsite_stream.write(' __syncthreads();\n', sdfg, state_id, + scope_entry) + # Grid synchronization (kernel fusion) + elif scope_entry.map.schedule == types.ScheduleType.GPU_Device: + callsite_stream.write(' __gbar.Sync();\n', sdfg, state_id, + scope_entry) + + def generate_node(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + if CUDACodeGen.node_dispatch_predicate(sdfg, node): + # Dynamically obtain node generator according to class name + gen = getattr(self, '_generate_' + type(node).__name__) + gen(sdfg, dfg, state_id, node, function_stream, callsite_stream) + return + + if not self._in_device_code: + self._cpu_codegen.generate_node(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + return + + self._locals.clear_scope(self._code_state.indentation + 1) + + if self._in_device_code and isinstance(node, nodes.MapExit): + return # skip + + self._cpu_codegen.generate_node(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + + def _generate_NestedSDFG(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + self._toplevel_schedule = node.schedule + self._cpu_codegen._generate_NestedSDFG( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + + def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + if node.map.schedule == types.ScheduleType.GPU_ThreadBlock: + # Close block invocation conditions + for i in range(len(node.map.params)): + callsite_stream.write('}', sdfg, state_id, node) + elif node.map.schedule == types.ScheduleType.GPU_ThreadBlock_Dynamic: + # Close lambda function + callsite_stream.write('});', sdfg, state_id, node) + return + + self._cpu_codegen._generate_MapExit(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + + def _generate_Reduce(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + # Try to autodetect reduction type + redtype = operations.detect_reduction_type(node.wcr) + schedule = node.schedule + node_id = dfg.node_id(node) + idstr = '{sdfg}_{state}_{node}'.format( + sdfg=sdfg.name, state=state_id, node=node_id) + + output_edge = dfg.out_edges(node)[0] + output_memlet = output_edge.data + output_type = 'dace::vec<%s, %s>' % ( + sdfg.arrays[output_memlet.data].dtype.ctype, output_memlet.veclen) + + if node.identity is None: + raise ValueError('For GPU reduce nodes, initial value must be ' + 'defined') + + # Create a 
functor or use an existing one for reduction + if redtype == types.ReductionType.Custom: + body, arg1, arg2 = unparse_cr_split(node.wcr) + self._globalcode.write( + """ + struct __reduce_{id} {{ + template + DACE_HDFI T operator()(const T &{arg1}, const T &{arg2}) const {{ + {contents} + }} + }};""".format(id=idstr, arg1=arg1, arg2=arg2, contents=body), sdfg, + state_id, node_id) + reduce_op = ', __reduce_' + idstr + '(), ' + _topy(node.identity) + elif redtype in _SPECIAL_RTYPES: + reduce_op = '' + else: + credtype = 'dace::ReductionType::' + str( + redtype)[str(redtype).find('.') + 1:] + reduce_op = ( + (', dace::_wcr_fixed<%s, %s>()' % (credtype, output_type)) + + ', ' + _topy(node.identity)) + + # Obtain some SDFG-related information + input_data = dfg.memlet_path(dfg.in_edges(node)[0])[0].src + output_data = dfg.memlet_path(dfg.out_edges(node)[0])[-1].dst + input_memlet = dfg.in_edges(node)[0].data + reduce_shape = input_memlet.subset.bounding_box_size() + num_items = ' * '.join([_topy(s) for s in reduce_shape]) + input = (input_memlet.data + ' + ' + cpp_array_expr( + sdfg, input_memlet, with_brackets=False)) + output = (output_memlet.data + ' + ' + cpp_array_expr( + sdfg, output_memlet, with_brackets=False)) + + # Options: Device-wide reduction (even from device code), + # block-wide reduction, sequential reduction (for loop) + if node.schedule == types.ScheduleType.GPU_Device: + # Verify that data is on the GPU + if input_data.desc(sdfg).storage not in [ + types.StorageType.GPU_Global, types.StorageType.CPU_Pinned + ]: + raise ValueError('Input of GPU reduction must either reside ' + ' in global GPU memory or pinned CPU memory') + if output_data.desc(sdfg).storage not in [ + types.StorageType.GPU_Global, types.StorageType.CPU_Pinned + ]: + raise ValueError('Output of GPU reduction must either reside ' + ' in global GPU memory or pinned CPU memory') + + # TODO(later): Enable device-wide reduction from device through + # CUDA dynamic parallelism. It is disabled right now + # due to temporary memory allocation (which needs to be done + # on the host). 
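+            # For reference, the code emitted below follows the standard
+            # two-phase CUB pattern (sketch of the generated CUDA C++;
+            # identifiers are illustrative):
+            #
+            #   // Phase 1: query temporary storage size with null pointers
+            #   cub::DeviceReduce::Sum(nullptr, temp_bytes, d_in, d_out, n);
+            #   cudaMalloc(&d_temp, temp_bytes);
+            #   // Phase 2: run the actual reduction on the chosen stream
+            #   cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n,
+            #                          stream);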
+ if self._in_device_code: + raise NotImplementedError('Device-wide reduction can only be' + ' run on non-GPU code.') + + # Determine reduction type + kname = (_SPECIAL_RTYPES[redtype] + if redtype in _SPECIAL_RTYPES else 'Reduce') + + # Create temp memory for this GPU + self._globalcode.write( + """ + void *__cub_storage_{sdfg}_{state}_{node} = NULL; + size_t __cub_ssize_{sdfg}_{state}_{node} = 0; + """.format(sdfg=sdfg.name, state=state_id, node=node_id), sdfg, + state_id, node) + + # Call CUB to get the storage size, allocate and free it + self.scope_entry_stream.write( + """ + cub::DeviceReduce::{kname}(nullptr, __cub_ssize_{sdfg}_{state}_{node}, + ({intype}*)nullptr, ({outtype}*)nullptr, {num_items}{redop}); + cudaMalloc(&__cub_storage_{sdfg}_{state}_{node}, __cub_ssize_{sdfg}_{state}_{node}); +""".format(sdfg=sdfg.name, + state=state_id, + node=node_id, + num_items=num_items, + redop=reduce_op, + intype=input_data.desc(sdfg).dtype.ctype, + outtype=output_data.desc(sdfg).dtype.ctype, + kname=kname), sdfg, state_id, node) + + self.scope_exit_stream.write( + 'cudaFree(__cub_storage_{sdfg}_{state}_{node});'.format( + sdfg=sdfg.name, state=state_id, node=node_id), sdfg, + state_id, node) + + max_streams = Config.get('compiler', 'cuda', + 'max_concurrent_streams') + if max_streams >= 0: + cudastream = 'dace::cuda::__streams[%d]' % node._cuda_stream + else: + cudastream = 'nullptr' + + # Write reduction function definition + self._localcode.write( + """ +DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, + size_t num_items); +void __dace_reduce_{id}({intype} *input, {outtype} *output, size_t num_items) +{{ + cub::DeviceReduce::{kname}(__cub_storage_{id}, __cub_ssize_{id}, + input, output, num_items{redop}, {stream}); +}} + """.format( + id=idstr, + intype=input_data.desc(sdfg).dtype.ctype, + outtype=output_data.desc(sdfg).dtype.ctype, + kname=kname, + redop=reduce_op, + stream=cudastream), sdfg, state_id, node) + + # Write reduction function definition in caller file + function_stream.write( + """ +DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, + size_t num_items); + """.format( + id=idstr, + intype=input_data.desc(sdfg).dtype.ctype, + outtype=output_data.desc(sdfg).dtype.ctype), sdfg, + state_id, node) + + # Call reduction function where necessary + input_dims = input_memlet.subset.dims() + output_dims = output_memlet.subset.data_dims() + if (node.axes is None or len(node.axes) == input_dims): + callsite_stream.write( + '__dace_reduce_{id}({input}, {output}, {num_items});' + .format( + id=idstr, + input=input, + output=output, + num_items=num_items), sdfg, state_id, node) + else: + raise NotImplementedError( + 'Multiple axis reductions not supported on GPUs. Please ' + 'apply ReduceExpansion') + # Generate for loops around CUB calls and properly offset input + # and output arrays + #for axis in range(output_dims): + # if axis not in node.axes: + # callsite_stream.write( + # 'for (int {var} = {begin}; {var} < {end}; {var} += {skip}) {{'. 
+ # format( + # var='__o%d' % axis, + # begin=output_subset[axis][0], + # end=output_subset[axis][1] + 1, + # skip=output_subset[axis][2]), sdfg, state_id, node) + # + ### Obtain variable names per output and reduction axis + #axis_vars = [] + #octr = 0 + #for d in range(input_dims): + # if d not in axes: + # axis_vars.append('__o%d' % octr) + # octr += 1 + # + #input = (input_memlet.data.name + ' + ' + cpp_array_expr( + # sdfg, input_memlet, with_brackets=False)) + #output = (output_memlet.data.name + ' + ' + cpp_array_expr( + # sdfg, output_memlet, with_brackets=False)) + #num_items = + # + #callsite_stream.write( + # '__dace_reduce_{id}({input}, {output}, {num_items});' + # .format( + # id=idstr, + # input=input, + # output=output, + # num_items=num_items), sdfg, state_id, node) + # + ##cpp_array_expr(sdfg, + ## output_memlet, + ## offset=['__o%d' % i for i in range(output_dims)], + ## relative_offset=False)) + ## invar = cpp_array_expr(sdfg, + ## input_memlet, offset=axis_vars, relative_offset=False) + #for axis in range(output_dims): + # callsite_stream.write('}\n', sdfg, state_id, node) + return + + # Block-wide reduction + elif node.schedule == types.ScheduleType.GPU_ThreadBlock: + # Checks + if not self._in_device_code: + raise ValueError('Block-wide GPU reduction must occur within' + ' a GPU kernel') + for bdim in self._block_dims: + if symbolic.issymbolic(bdim, sdfg.constants): + raise ValueError( + 'Block size has to be constant for block-wide ' + 'reduction (got %s)' % str(bdim)) + if (node.axes is not None and len(node.axes) < input_dims): + raise ValueError( + 'Only full reduction is supported for block-wide reduce,' + ' please use ReduceExpansion') + if (input_data.desc(sdfg).storage != types.StorageType.GPU_Stack + or output_data.desc(sdfg).storage != + types.StorageType.GPU_Stack): + raise ValueError( + 'Block-wise reduction only supports GPU register inputs ' + 'and outputs') + if redtype in _SPECIAL_RTYPES: + raise ValueError('%s block reduction not supported' % redtype) + + credtype = 'dace::ReductionType::' + str( + redtype)[str(redtype).find('.') + 1:] + if redtype == types.ReductionType.Custom: + redop = '__reduce_%s()' % idstr + else: + redop = 'dace::_wcr_fixed<%s, %s>()' % (credtype, output_type) + + # Allocate shared memory for block reduce + self.scope_entry_stream.write( + """ + typedef cub::BlockReduce<{type}, {numthreads}> BlockReduce_{id}; + __shared__ typename BlockReduce_{id}::TempStorage temp_storage_{id}; + """.format( + id=idstr, + type=output_data.desc(sdfg).dtype.ctype, + numthreads=' * '.join(str(s) for s in self._block_dims)), + sdfg, state_id, node) + + # TODO(later): If less than the whole block is participating, + # use special CUB function + output = cpp_array_expr(sdfg, output_memlet) + callsite_stream.write( + """ + {output} = BlockReduce_{id}(temp_storage_{id}).Reduce({input}, {redop}); + """.format( + id=idstr, + redop=redop, + input=input_memlet.data, + output=output), sdfg, state_id, node) + + return + # Sequential goes to CPU generator + elif node.schedule == types.ScheduleType.Sequential: + self._cpu_codegen._generate_Reduce( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + return + else: + raise ValueError( + 'Unsupported reduction schedule %s' % str(node.schedule)) + + +######################################################################## +######################################################################## +######################################################################## 
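+# For reference, the block-wide reduction emitted above instantiates CUB
+# roughly as follows (illustrative CUDA C++ for a 128-thread block of
+# floats; identifiers are hypothetical):
+#
+#   typedef cub::BlockReduce<float, 128> BlockReduce_prog_0_5;
+#   __shared__ typename BlockReduce_prog_0_5::TempStorage temp_storage_prog_0_5;
+#   out = BlockReduce_prog_0_5(temp_storage_prog_0_5).Reduce(
+#       in, dace::_wcr_fixed<dace::ReductionType::Sum, float>());
+#
+# The thread-count template parameter is the product of the block
+# dimensions, which is why it must be a compile-time constant.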
+######################################################################## +# Helper functions and classes + + +def unparse_cr_split(wcr_ast): + """ Parses various types of WCR functions, returning a 3-tuple of body, + first argument name and second argument name. """ + if isinstance(wcr_ast, ast.FunctionDef): + return (cppunparse.cppunparse(wcr_ast.body, expr_semicolon=False), + wcr_ast.args.args[0].arg, wcr_ast.args.args[1].arg) + elif isinstance(wcr_ast, ast.Lambda): + return (('return (' + cppunparse.cppunparse( + wcr_ast.body, expr_semicolon=False) + ');'), + wcr_ast.args.args[0].arg, wcr_ast.args.args[1].arg) + elif isinstance(wcr_ast, ast.Module): + return unparse_cr_split(wcr_ast.body[0].value) + elif isinstance(wcr_ast, str): + return unparse_cr_split(LambdaProperty.from_string(wcr_ast)) + else: + raise NotImplementedError('INVALID TYPE OF WCR: ' + + type(wcr_ast).__name__) + + +def _topy(arr): + """ Converts an array of symbolic variables (or one) to C++ strings. """ + if not isinstance(arr, list): + return cppunparse.pyexpr2cpp(symbolic.symstr(arr)) + return [cppunparse.pyexpr2cpp(symbolic.symstr(d)) for d in arr] + + +def _named_idx(idx): + """ Converts 0 to x, 1 to y, 2 to z, or raises an exception. """ + if idx < 0 or idx > 2: + raise ValueError('idx must be between 0 and 2, got %d' % idx) + return ('x', 'y', 'z')[idx] + + +def _get_storagename(storage): + """ Returns a string containing the name of the storage location. + Example: types.StorageType.GPU_Shared will return "Shared". """ + sname = str(storage) + return sname[sname.rindex('_') + 1:] diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py new file mode 100644 index 0000000000..763fe640ae --- /dev/null +++ b/dace/codegen/targets/framecode.py @@ -0,0 +1,936 @@ +from typing import Set + +import collections +import dace +import functools +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.targets.target import TargetCodeGenerator, TargetDispatcher +from dace.sdfg import SDFG, SDFGState, ScopeSubgraphView +from dace.graph import nodes +from dace import types, config + +from dace.frontend.python import ndarray +from dace.codegen.instrumentation.perfsettings import PerfSettings, PerfUtils +from dace.codegen import cppunparse + +import networkx as nx +import numpy as np + + +class DaCeCodeGenerator(object): + """ DaCe code generator class that writes the generated code for SDFG + state machines, and uses a dispatcher to generate code for + individual states based on the target. 
""" + + def __init__(self, *args, **kwargs): + self._dispatcher = TargetDispatcher() + self._dispatcher.register_state_dispatcher(self) + self._initcode = CodeIOStream() + self._exitcode = CodeIOStream() + + ################################################################## + # Target registry + + @property + def dispatcher(self): + return self._dispatcher + + ################################################################## + # Code generation + + def generate_constants(self, sdfg: SDFG, callsite_stream: CodeIOStream): + # Write constants + for cstname, cstval in sdfg.constants.items(): + if isinstance(cstval, np.ndarray): + if isinstance(cstval, ndarray.ndarray): + dtype = cstval.descriptor.dtype + else: + dtype = types.typeclass(cstval.dtype.type) + const_str = "constexpr " + dtype.ctype + \ + " " + cstname + "[" + str(cstval.size) + "] = {" + it = np.nditer(cstval, order='C') + for i in range(cstval.size - 1): + const_str += str(it[0]) + ", " + it.iternext() + const_str += str(it[0]) + "};\n" + callsite_stream.write(const_str, sdfg) + else: + callsite_stream.write( + "constexpr auto %s = %s;\n" % (cstname, str(cstval)), sdfg) + + def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream): + """ Generate a header in every output file that includes custom types + and constants. + @param sdfg: The input SDFG. + @param global_stream: Stream to write to (global). + """ + ######################################################### + # Custom types + types = set() + # Types of this SDFG + for sdfg, arrname, arr in sdfg.arrays_recursive(): + if arr is not None: + types.add(arr.dtype) + + # Emit unique definitions + global_stream.write('\n') + for typ in types: + if hasattr(typ, 'emit_definition'): + global_stream.write(typ.emit_definition(), sdfg) + global_stream.write('\n') + + ######################################################### + # Write constants + self.generate_constants(sdfg, global_stream) + + def generate_header(self, sdfg: SDFG, global_stream: CodeIOStream, + callsite_stream: CodeIOStream): + """ Generate the header of the frame-code. Code exists in a separate + function for overriding purposes. + @param sdfg: The input SDFG. + @param global_stream: Stream to write to (global). + @param callsite_stream: Stream to write to (at call site). + """ + fname = sdfg.name + params = sdfg.signature() + + # Write frame code - header + global_stream.write( + '/* DaCe AUTO-GENERATED FILE. DO NOT MODIFY */\n' + + '#include \n', sdfg) + + # Added for instrumentation includes + if PerfSettings.perf_enable_instrumentation(): + global_stream.write( + '/* DaCe instrumentation include */\n' + + '#include \n', sdfg) + + self.generate_fileheader(sdfg, callsite_stream) + + callsite_stream.write( + 'void __program_%s_internal(%s)\n{\n' % (fname, params), sdfg) + + # Define the performance store (autocleanup on destruction) + if PerfSettings.perf_enable_instrumentation(): + callsite_stream.write( + 'dace_perf::PAPI::init();\n' + 'dace_perf::%s __perf_store;\n' + % PerfUtils.perf_counter_store_string( + PerfSettings.perf_default_papi_counters()), sdfg) + + def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, + callsite_stream: CodeIOStream): + """ Generate the footer of the frame-code. Code exists in a separate + function for overriding purposes. + @param sdfg: The input SDFG. + @param global_stream: Stream to write to (global). + @param callsite_stream: Stream to write to (at call site). 
+ """ + fname = sdfg.name + params = sdfg.signature() + paramnames = sdfg.signature(False) + + # Write frame code - footer + callsite_stream.write('}\n', sdfg) + + # Write awkward footer to avoid 'extern "C"' issues + callsite_stream.write( + """ +void __program_%s_internal(%s); +DACE_EXPORTED void __program_%s(%s) +{ + __program_%s_internal(%s); +} +""" % (fname, params, fname, params, fname, paramnames), sdfg) + + for target in self._dispatcher.used_targets: + if target.has_initializer: + callsite_stream.write( + 'DACE_EXPORTED int __dace_init_%s(%s);\n' % + (target.target_name, params), sdfg) + if target.has_finalizer: + callsite_stream.write( + 'DACE_EXPORTED int __dace_exit_%s(%s);\n' % + (target.target_name, params), sdfg) + + callsite_stream.write( + """ +DACE_EXPORTED int __dace_init(%s) +{ + int result = 0; +""" % params, sdfg) + + for target in self._dispatcher.used_targets: + if target.has_initializer: + callsite_stream.write( + 'result |= __dace_init_%s(%s);' % (target.target_name, + paramnames), sdfg) + + callsite_stream.write(self._initcode.getvalue(), sdfg) + + callsite_stream.write( + """ + return result; +} + +DACE_EXPORTED void __dace_exit(%s) +{ +""" % params, sdfg) + + callsite_stream.write(self._exitcode.getvalue(), sdfg) + + for target in self._dispatcher.used_targets: + if target.has_finalizer: + callsite_stream.write( + '__dace_exit_%s(%s);' % (target.target_name, paramnames), + sdfg) + + callsite_stream.write('}\n', sdfg) + + def generate_state(self, + sdfg, + state, + global_stream, + callsite_stream, + generate_state_footer=True): + + sid = sdfg.node_id(state) + + # Emit internal transient array allocation + # Don't allocate transients shared with another state + data_to_allocate = ( + set(state.top_level_transients()) - set(sdfg.shared_transients())) + allocated = set() + for node in state.data_nodes(): + if node.data not in data_to_allocate or node.data in allocated: + continue + allocated.add(node.data) + self._dispatcher.dispatch_allocate(sdfg, state, sid, node, + global_stream, callsite_stream) + self._dispatcher.dispatch_initialize( + sdfg, state, sid, node, global_stream, callsite_stream) + + ##################### + # Create dataflow graph for state's children. + + # DFG to code scheme: Only generate code for nodes whose all + # dependencies have been executed (topological sort). + # For different connected components, run them concurrently. + + components = dace.sdfg.concurrent_subgraphs(state) + + if len(components) == 1: + self._dispatcher.dispatch_subgraph( + sdfg, + state, + sid, + global_stream, + callsite_stream, + skip_entry_node=False) + else: + ############################################################# + # Instrumentation: Pre-state + # We cannot have supersections starting in parallel + parent_id = PerfUtils.unified_id(-1, sid) + if PerfSettings.perf_enable_instrumentation(): + callsite_stream.write( + "__perf_store.markSuperSectionStart(%d);\n" % + PerfUtils.unified_id(-1, sid)) + ############################################################# + + callsite_stream.write("#pragma omp parallel sections\n{") + for c in components: + c.set_parallel_parent( + parent_id + ) # Keep in mind not to add supersection start markers! 
+ callsite_stream.write("#pragma omp section\n{") + self._dispatcher.dispatch_subgraph( + sdfg, + c, + sid, + global_stream, + callsite_stream, + skip_entry_node=False) + callsite_stream.write("} // End omp section") + callsite_stream.write("} // End omp sections") + + ##################### + # Write state footer + + if generate_state_footer: + # Emit internal transient array deallocation + deallocated = set() + for node in state.data_nodes(): + if node.data not in data_to_allocate or node.data in deallocated: + continue + deallocated.add(node.data) + self._dispatcher.dispatch_deallocate( + sdfg, state, sid, node, global_stream, callsite_stream) + + @staticmethod + def _generate_assignments(assignments): + return [ + "{} = {}".format(variable, value) + for variable, value in assignments.items() + ] + + @staticmethod + def _is_always_true(condition_string): + return condition_string in ["true", "1"] + + def _generate_transition(self, sdfg, sid, callsite_stream, edge, + assignments): + + condition_string = cppunparse.cppunparse(edge.data.condition, False) + always_true = self._is_always_true(condition_string) + + if not always_true: + callsite_stream.write("if ({}) {{".format(condition_string), sdfg, + sid) + + if len(assignments) > 0: + callsite_stream.write( + ";\n".join( + DaCeCodeGenerator._generate_assignments(assignments) + + [""]), sdfg, sid) + + callsite_stream.write( + "goto __state_{}_{};".format(sdfg.name, edge.dst.label), sdfg, sid) + + if not always_true: + callsite_stream.write("}") + + def generate_states(self, sdfg, scope_label, control_flow, global_stream, + callsite_stream, scope, states_generated): + + states_topological = list(sdfg.topological_sort(sdfg.start_state)) + states_to_generate = collections.deque([ + s for s in states_topological + if s in scope and s not in states_generated + ]) + if len(states_to_generate) == 0: + return + + while len(states_to_generate) > 0: + + state = states_to_generate.popleft() + # When generating control flow constructs, we will not necessarily + # move in topological order, so make sure this state has not + # already been generated. 
+ if state in states_generated or state not in scope: + continue + states_generated.add(state) + + sid = sdfg.node_id(state) + + callsite_stream.write( + "__state_{}_{}:\n".format(sdfg.name, state.label), sdfg, sid) + + # Don't generate brackets and comments for empty states + if len([ + n for n in state.nodes() + if not isinstance(n, dace.graph.nodes.EmptyTasklet) + ]) > 0: + + callsite_stream.write('{', sdfg, sid) + + self._dispatcher.dispatch_state(sdfg, state, global_stream, + callsite_stream) + + callsite_stream.write('}', sdfg, sid) + + else: + + callsite_stream.write(";") + + out_edges = sdfg.out_edges(state) + + # Write conditional branches to next states + for edge in out_edges: + + generate_assignments = True + generate_transition = True + + # Handle specialized control flow + if (dace.config.Config.get_bool('optimizer', + 'detect_control_flow')): + + for control in control_flow[edge]: + + if isinstance(control, + dace.graph.edges.LoopAssignment): + # Generate the transition, but leave the + # assignments to the loop + generate_transition = True + generate_assignments = False + + elif isinstance(control, dace.graph.edges.LoopBack): + generate_transition = False + generate_assignments = False + + elif isinstance(control, dace.graph.edges.LoopExit): + # Need to strip the condition, so generate it from + # the loop entry + generate_transition = False + generate_assignments = True + pass + + elif isinstance(control, dace.graph.edges.LoopEntry): + generate_transition = False + generate_assignments = False + + if control.scope.assignment is not None: + assignment_edge = control.scope.assignment.edge + init_assignments = ", ".join( + DaCeCodeGenerator._generate_assignments( + assignment_edge.data.assignments)) + else: + init_assignments = "" + + back_edge = control.scope.back.edge + continue_assignments = ", ".join( + DaCeCodeGenerator._generate_assignments( + back_edge.data.assignments)) + + entry_edge = control.scope.entry.edge + condition = cppunparse.cppunparse( + entry_edge.data.condition, False) + + if (len(init_assignments) > 0 + or len(continue_assignments) > 0): + callsite_stream.write( + "for ({}; {}; {}) {{".format( + init_assignments, condition, + continue_assignments), sdfg, sid) + else: + callsite_stream.write( + "while ({}) {{".format(condition), sdfg, + sid) + + # Generate loop body + self.generate_states( + sdfg, entry_edge.src.label + "_loop", + control_flow, global_stream, callsite_stream, + control.scope, states_generated) + + callsite_stream.write("}", sdfg, sid) + + exit_edge = control.scope.exit.edge + + # Update states to generate after nested call + states_to_generate = collections.deque([ + s for s in states_to_generate + if s not in states_generated + ]) + # If the next state to be generated is the exit + # state, we can omit the goto + if (len(states_to_generate) > 0 + and states_to_generate[0] == exit_edge.dst + and exit_edge.dst not in states_generated): + pass + else: + callsite_stream.write( + "goto __state_{}_{};".format( + sdfg.name, + control.scope.exit.edge.dst)) + + elif isinstance(control, dace.graph.edges.IfExit): + generate_transition = True + generate_assignments = True + + elif isinstance(control, dace.graph.edges.IfEntry): + generate_transition = False + generate_assignments = True + + if len(set(control.scope) - states_generated) == 0: + continue + + then_scope = control.scope.if_then_else.then_scope + else_scope = control.scope.if_then_else.else_scope + + then_entry = then_scope.entry.edge + + condition = cppunparse.cppunparse( + 
then_entry.data.condition, False) + + callsite_stream.write( + "if ({}) {{".format(condition), sdfg, sid) + + # Generate the then-scope + self.generate_states(sdfg, state.label + "_then", + control_flow, global_stream, + callsite_stream, then_scope, + states_generated) + + callsite_stream.write("} else {", sdfg, sid) + + # Generate the else-scope + self.generate_states(sdfg, state.label + "_else", + control_flow, global_stream, + callsite_stream, else_scope, + states_generated) + + callsite_stream.write("}", sdfg, sid) + + # Update states to generate after nested call + states_to_generate = collections.deque([ + s for s in states_to_generate + if s not in states_generated + ]) + + if_exit_state = control.scope.exit.edge.dst + + if ((if_exit_state not in states_generated) and + ((len(states_to_generate) > 0) and + (states_to_generate[0] == if_exit_state))): + pass + else: + callsite_stream.write( + "goto __state_{}_{};".format( + sdfg.name, + control.scope.exit.edge.dst)) + + else: + + raise TypeError( + "Unknown control flow \"{}\"".format( + type(control).__name__)) + + if generate_assignments and len(edge.data.assignments) > 0: + assignments_to_generate = edge.data.assignments + else: + assignments_to_generate = {} + + if generate_transition: + + if ((len(out_edges) == 1) + and (edge.dst not in states_generated) + and ((len(states_to_generate) > 0) and + (states_to_generate[0] == edge.dst))): + # If there is only one outgoing edge, the target will + # be generated next, we can omit the goto + pass + elif (len(out_edges) == 1 and len(states_to_generate) == 0 + and (edge.dst not in scope)): + # This scope has ended, and we don't need to generate + # any output edge + pass + else: + self._generate_transition(sdfg, sid, callsite_stream, + edge, + assignments_to_generate) + # Assignments will be generated in the transition + generate_assignments = False + + if generate_assignments: + + callsite_stream.write( + ";\n".join( + DaCeCodeGenerator._generate_assignments( + assignments_to_generate) + [""]), sdfg, sid) + + if (((len(out_edges) == 0) or + (not isinstance(scope, dace.graph.edges.ControlFlowScope) and + (len(states_to_generate) == 0))) + and (len(states_generated) != sdfg.number_of_nodes())): + callsite_stream.write( + "goto __state_exit_{}_{};".format(sdfg.name, scope_label), + sdfg, sid) + + # Write exit state + callsite_stream.write( + "__state_exit_{}_{}:;".format(sdfg.name, scope_label), sdfg) + + @staticmethod + def all_nodes_between(graph, begin, end): + """Finds all nodes between begin and end. Returns None if there is any + path starting at begin that does not reach end.""" + to_visit = [begin] + seen = set() + while len(to_visit) > 0: + n = to_visit.pop() + if n == end: + continue # We've reached the end node + if n in seen: + continue # We've already visited this node + seen.add(n) + # Keep chasing all paths to reach the end node + node_out_edges = graph.out_edges(n) + if len(node_out_edges) == 0: + # We traversed to the end without finding the end + return None + for e in node_out_edges: + next_node = e.dst + if next_node != end and next_node not in seen: + to_visit.append(next_node) + return seen + + def generate_code(self, + sdfg: SDFG, + schedule: types.ScheduleType, + sdfg_id: str = "" + ) -> (str, str, Set[TargetCodeGenerator]): + """ Generate frame code for a given SDFG, calling registered targets' + code generation callbacks for them to generate their own code. + @param sdfg: The SDFG to generate code for. 
+ @param schedule: The schedule the SDFG is currently located, or + None if the SDFG is top-level. + @param sdfg_id: An optional string id given to the SDFG label + @return: A tuple of the generated global frame code, local frame + code, and a set of targets that have been used in the + generation of this SDFG. + """ + + sdfg_label = sdfg.name + sdfg_id + + global_stream = CodeIOStream() + callsite_stream = CodeIOStream() + + # Set default storage/schedule types in SDFG + _set_default_schedule_and_storage_types(sdfg, schedule) + + # Generate preamble (if top-level) + if schedule is None: + self.generate_header(sdfg, global_stream, callsite_stream) + + # Generate code + ########################### + + if sdfg.parent is not None: + # Nested SDFG + symbols_available = sdfg.parent.symbols_defined_at(sdfg) + else: + symbols_available = sdfg.constants + + # Allocate outer-level transients + shared_transients = sdfg.shared_transients() + allocated = set() + for state in sdfg.nodes(): + for node in state.data_nodes(): + if (node.data in shared_transients + and node.data not in allocated): + self._dispatcher.dispatch_allocate(sdfg, state, None, node, + global_stream, + callsite_stream) + self._dispatcher.dispatch_initialize( + sdfg, state, None, node, global_stream, + callsite_stream) + allocated.add(node.data) + + # Allocate inter-state variables + assigned, _ = sdfg.interstate_symbols() + for isvarName, isvarType in assigned.items(): + # Skip symbols that have been declared as outer-level transients + if isvarName in allocated: + continue + callsite_stream.write( + '%s;\n' % (isvarType.signature( + with_types=True, name=isvarName)), sdfg) + + # Initialize parameter arrays + for argnode in types.deduplicate(sdfg.input_arrays() + + sdfg.output_arrays()): + # Ignore transient arrays + if argnode.desc(sdfg).transient: continue + self._dispatcher.dispatch_initialize( + sdfg, sdfg, None, argnode, global_stream, callsite_stream) + + callsite_stream.write('\n', sdfg) + + states_topological = list(sdfg.topological_sort(sdfg.start_state)) + + # {edge: [dace.edges.ControlFlow]} + control_flow = {e: [] for e in sdfg.edges()} + + if dace.config.Config.get_bool('optimizer', 'detect_control_flow'): + + #################################################################### + # Loop detection procedure + + all_cycles = list(sdfg.find_cycles()) # Returns a list of lists + # Order according to topological sort + all_cycles = [ + sorted(c, key=lambda x: states_topological.index(x)) + for c in all_cycles + ] + # Group in terms of starting node + starting_nodes = [c[0] for c in all_cycles] + cycles_by_node = [[c for c in all_cycles if c[0] == n] + for n in starting_nodes] + for cycles in cycles_by_node: + + # Use arbitrary cycle to find the first and last nodes + first_node = cycles[0][0] + last_node = cycles[0][-1] + + if not first_node.is_empty(): + # The entry node should not contain any computations + continue + + if not all([c[-1] == last_node for c in cycles]): + # There are multiple back edges: not a for or while loop + continue + + previous_edge = [ + e for e in sdfg.in_edges(first_node) if e.src != last_node + ] + if len(previous_edge) != 1: + # No single starting point: not a for or while + continue + previous_edge = previous_edge[0] + + back_edge = sdfg.edges_between(last_node, first_node) + if len(back_edge) != 1: + raise RuntimeError("Expected exactly one edge in cycle") + back_edge = back_edge[0] + + # Build a set of all nodes in all cycles associated with this + # set of start and end node + 
internal_nodes = functools.reduce(
+                    lambda a, b: a | b, [set(c)
+                                         for c in cycles]) - {first_node}
+
+                exit_edge = [
+                    e for e in sdfg.out_edges(first_node)
+                    if e.dst not in internal_nodes | {first_node}
+                ]
+                if len(exit_edge) != 1:
+                    # No single stopping condition: not a for or while
+                    # (we don't support continue or break)
+                    continue
+                exit_edge = exit_edge[0]
+
+                entry_edge = [
+                    e for e in sdfg.out_edges(first_node) if e != exit_edge
+                ]
+                if len(entry_edge) != 1:
+                    # No single starting condition: not a for or while
+                    continue
+                entry_edge = entry_edge[0]
+
+                # Make sure this is not already annotated to be another construct
+                if (len(control_flow[entry_edge]) != 0
+                        or len(control_flow[back_edge]) != 0
+                        or len(control_flow[exit_edge]) != 0):
+                    continue
+
+                if entry_edge == back_edge:
+                    # No entry check (we don't support do-loops)
+                    # TODO: do we want to add some support for self-loops?
+                    continue
+
+                # Now we make sure that there is no other way to exit this
+                # cycle, by checking that there's no reachable node *not*
+                # included in any cycle between the first and last node.
+                if any([len(set(c) - internal_nodes) > 1 for c in cycles]):
+                    continue
+
+                # This is a loop! Generate the necessary annotation objects.
+                loop_scope = dace.graph.edges.LoopScope(internal_nodes)
+
+                if ((len(previous_edge.data.assignments) > 0
+                     or len(back_edge.data.assignments) > 0)
+                        and len(control_flow[previous_edge]) == 0):
+                    # Generate assignment edge, if available
+                    control_flow[previous_edge].append(
+                        dace.graph.edges.LoopAssignment(
+                            loop_scope, previous_edge))
+                # Assign remaining control flow constructs
+                control_flow[entry_edge].append(
+                    dace.graph.edges.LoopEntry(loop_scope, entry_edge))
+                control_flow[exit_edge].append(
+                    dace.graph.edges.LoopExit(loop_scope, exit_edge))
+                control_flow[back_edge].append(
+                    dace.graph.edges.LoopBack(loop_scope, back_edge))
+
+            ###################################################################
+            # If/then/else detection procedure
+
+            candidates = [
+                n for n in states_topological if sdfg.out_degree(n) == 2
+            ]
+            for candidate in candidates:
+
+                # A valid if occurs when there is no reachable node on either
+                # path that does not pass through a common dominator.
+                dominators = nx.dominance.dominance_frontiers(
+                    sdfg.nx, candidate)
+
+                left_entry, right_entry = sdfg.out_edges(candidate)
+                if (len(control_flow[left_entry]) > 0
+                        or len(control_flow[right_entry]) > 0):
+                    # Already assigned to a control flow construct
+                    # TODO: carefully allow this in some cases
+                    continue
+
+                left, right = left_entry.dst, right_entry.dst
+                dominator = dominators[left] & dominators[right]
+                if len(dominator) != 1:
+                    # There must be a single dominator across both branches,
+                    # unless one of the nodes _is_ the next dominator
+                    # if (len(dominator) == 0 and dominators[left] == {right}
+                    #         or dominators[right] == {left}):
+                    #     dominator = dominators[left] | dominators[right]
+                    # else:
+                    #     continue
+                    continue
+                dominator = next(iter(dominator))  # Exactly one dominator
+
+                exit_edges = sdfg.in_edges(dominator)
+                if len(exit_edges) != 2:
+                    # There must be a single entry and a single exit. This
+                    # could be relaxed in the future.
+ continue + + left_exit, right_exit = exit_edges + if (len(control_flow[left_exit]) > 0 + or len(control_flow[right_exit]) > 0): + # Already assigned to a control flow construct + # TODO: carefully allow this in some cases + continue + + # Now traverse from the source and verify that all possible paths + # pass through the dominator + left_nodes = DaCeCodeGenerator.all_nodes_between( + sdfg, left, dominator) + if left_nodes is None: + # Not all paths lead to the next dominator + continue + right_nodes = DaCeCodeGenerator.all_nodes_between( + sdfg, right, dominator) + if right_nodes is None: + # Not all paths lead to the next dominator + continue + all_nodes = left_nodes | right_nodes + + # Make sure there is no overlap between left and right nodes + if len(left_nodes & right_nodes) > 0: + continue + + # This is a valid if/then/else construct. Generate annotations + if_then_else = dace.graph.edges.IfThenElse( + candidate, dominator) + + # Arbitrarily assign then/else to the two branches. If one edge + # has no dominator but leads to the dominator, it means there's + # only a then clause (and no else). + has_else = False + if len(dominators[left]) == 1: + then_scope = dace.graph.edges.IfThenScope( + if_then_else, left_nodes) + else_scope = dace.graph.edges.IfElseScope( + if_then_else, right_nodes) + control_flow[left_entry].append( + dace.graph.edges.IfEntry(then_scope, left_entry)) + control_flow[left_exit].append( + dace.graph.edges.IfExit(then_scope, left_exit)) + control_flow[right_exit].append( + dace.graph.edges.IfExit(else_scope, right_exit)) + if len(dominators[right]) == 1: + control_flow[right_entry].append( + dace.graph.edges.IfEntry(else_scope, right_entry)) + has_else = True + else: + then_scope = dace.graph.edges.IfThenScope( + if_then_else, right_nodes) + else_scope = dace.graph.edges.IfElseScope( + if_then_else, left_nodes) + control_flow[right_entry].append( + dace.graph.edges.IfEntry(then_scope, right_entry)) + control_flow[right_exit].append( + dace.graph.edges.IfExit(then_scope, right_exit)) + control_flow[left_exit].append( + dace.graph.edges.IfExit(else_scope, left_exit)) + + ####################################################################### + # State transition generation + + states_generated = set() # For sanity check + self.generate_states(sdfg, "sdfg", control_flow, + global_stream, callsite_stream, + set(states_topological), states_generated) + + ############################# + # End of code generation + + if len(states_generated) != len(sdfg.nodes()): + raise RuntimeError( + "Not all states were generated in SDFG {}!" 
+                "\n  Generated: {}\n  Missing: {}".format(
+                    sdfg.label, [s.label for s in states_generated],
+                    [s.label for s in (set(sdfg.nodes()) - states_generated)]))
+
+        # Deallocate transients
+        shared_transients = sdfg.shared_transients()
+        deallocated = set()
+        for state in sdfg.nodes():
+            for node in state.data_nodes():
+                if (node.data in shared_transients
+                        and node.data not in deallocated):
+                    self._dispatcher.dispatch_deallocate(
+                        sdfg, sdfg, None, node, global_stream, callsite_stream)
+                    deallocated.add(node.data)
+
+        ###########################
+
+        # Generate footer (if top-level)
+        if schedule is None:
+            self.generate_footer(sdfg, global_stream, callsite_stream)
+
+        # Clear out all the annotated control flow
+
+        # Return the generated global and local code strings
+        return (global_stream.getvalue(), callsite_stream.getvalue(),
+                self._dispatcher.used_targets)
+
+
+def _set_default_schedule_and_storage_types(sdfg, toplevel_schedule):
+    """ Sets default storage and schedule types throughout SDFG.
+        Replaces `ScheduleType.Default` and `StorageType.Default`
+        with the corresponding types according to the parent scope's
+        schedule. """
+    for state in sdfg.nodes():
+        scope_dict = state.scope_dict()
+        reverse_scope_dict = state.scope_dict(node_to_children=True)
+
+        def set_default_in_scope(parent_node):
+            if parent_node is None:
+                parent_schedule = toplevel_schedule
+            else:
+                parent_schedule = parent_node.map.schedule
+
+            for node in reverse_scope_dict[parent_node]:
+                # Set default schedule type
+                if isinstance(node, nodes.MapEntry):
+                    if node.map.schedule == types.ScheduleType.Default:
+                        node.map._schedule = \
+                            types.SCOPEDEFAULT_SCHEDULE[parent_schedule]
+                    # Also traverse children (recursively)
+                    set_default_in_scope(node)
+                elif isinstance(node, nodes.ConsumeEntry):
+                    if node.consume.schedule == types.ScheduleType.Default:
+                        node.consume._schedule = \
+                            types.SCOPEDEFAULT_SCHEDULE[parent_schedule]
+                    # Also traverse children (recursively)
+                    set_default_in_scope(node)
+                elif getattr(node, 'schedule', False):
+                    if node.schedule == types.ScheduleType.Default:
+                        node._schedule = \
+                            types.SCOPEDEFAULT_SCHEDULE[parent_schedule]
+
+        ## End of recursive function
+
+        # Start with top-level nodes
+        set_default_in_scope(None)
+
+        # Set default storage type
+        for node in state.nodes():
+            if isinstance(node, nodes.AccessNode):
+                if node.desc(sdfg).storage == types.StorageType.Default:
+                    if scope_dict[node] is None:
+                        parent_schedule = toplevel_schedule
+                    else:
+                        parent_schedule = scope_dict[node].map.schedule
+
+                    node.desc(sdfg).storage = (
+                        types.SCOPEDEFAULT_STORAGE[parent_schedule])
+        ### End of storage type loop
diff --git a/dace/codegen/targets/immaterial.py b/dace/codegen/targets/immaterial.py
new file mode 100644
index 0000000000..8290b4badc
--- /dev/null
+++ b/dace/codegen/targets/immaterial.py
@@ -0,0 +1,239 @@
+import functools
+
+from dace import data, subsets, symbolic, types
+from dace.codegen.codeobject import CodeObject
+from dace.codegen.targets.target import TargetCodeGenerator
+from dace.codegen.targets.cpu import cpp_array_expr, sym2cpp
+from dace.graph import nodes
+
+from dace.codegen import cppunparse
+
+
+class ImmaterialCodeGen(TargetCodeGenerator):
+    """ Code generator for data nodes with immaterial (i.e., generated
+        from a function) storage.
""" + + target_name = 'Immaterial' + language = 'cpp' + + def __init__(self, frame_codegen, sdfg): + self._frame = frame_codegen + self._dispatcher = frame_codegen.dispatcher + dispatcher = self._dispatcher + + self.emitted_materialize_funcs = set() + + # Register dispatchers + dispatcher.register_array_dispatcher(types.StorageType.Immaterial, + self) + + cpu_storage = [ + types.StorageType.CPU_Heap, types.StorageType.CPU_Pinned, + types.StorageType.CPU_Stack, types.StorageType.Register + ] + for storage_type in cpu_storage: + dispatcher.register_copy_dispatcher(types.StorageType.Immaterial, + storage_type, None, self) + dispatcher.register_copy_dispatcher( + storage_type, types.StorageType.Immaterial, None, self) + + def get_generated_codeobjects(self): + return [] # Immaterial storage generates inline code + + @property + def has_initializer(self): + return False + + @property + def has_finalizer(self): + return False + + def allocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + callsite_stream.write("// allocate array\n", sdfg, state_id, node) + + def initialize_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + callsite_stream.write("// initialize_array " + node.data + "\n", sdfg, + state_id, node) + + def deallocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + callsite_stream.write("// deallocate_array", sdfg, state_id, node) + + def copy_memory(self, sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, callsite_stream): + memlet = edge.data + if (isinstance(src_node, nodes.AccessNode) + and (src_node.desc(sdfg).materialize_func is not None)): + function_stream.write(src_node.desc(sdfg).materialize_func) + + if edge.dst_conn is not None: + arrayname = str(edge.dst_conn) + else: + arrayname = str(dst_node.desc) + + if isinstance(dst_node, nodes.Tasklet) or \ + (dst_node.desc(sdfg).storage == types.StorageType.Register): + callsite_stream.write( + self.memlet_definition( + sdfg, memlet, arrayname, direction="in"), sdfg, + state_id, [src_node, dst_node]) + else: + callsite_stream.write("__dace_materialize(\"" + \ + sym2cpp(src_node) + "\", " + \ + sym2cpp(memlet.subset.min_element()[0]) + + ", " + \ + sym2cpp(memlet.subset.min_element()[0] + + memlet.subset.num_elements()) + + ", " + sym2cpp(dst_node.data) + ");\n", + sdfg, state_id, [src_node, dst_node]) + + if (isinstance(dst_node, nodes.AccessNode) + and (dst_node.desc(sdfg).materialize_func is not None)): + # This case is pretty complicated due to how the rest of the + # codegen works: This is not the place to actually copy code. In + # the place where data is ready to be written there will be a call + # __foo.write(foo) where foo is the local_name of the memlet that + # "causes" the write. But this function is actually called when + # we should set up everything for this call to work. 
+            # The write call itself is generated by process_out_memlets.
+
+            function_stream.write(dst_node.desc(sdfg).materialize_func)
+            if isinstance(src_node, nodes.Tasklet) or \
+               (src_node.desc(sdfg).storage == types.StorageType.Register):
+                callsite_stream.write(
+                    self.memlet_definition(
+                        sdfg, memlet, edge.src_conn, direction="out"), sdfg,
+                    state_id, [src_node, dst_node])
+            else:
+                callsite_stream.write("__dace_serialize(\"" + \
+                                      sym2cpp(dst_node) + "\", " + \
+                                      sym2cpp(memlet.subset.min_element()[0]) +
+                                      ", " + \
+                                      sym2cpp(memlet.subset.min_element()[0] +
+                                              memlet.subset.num_elements()) +
+                                      ", " + sym2cpp(src_node.data) + ");\n",
+                                      sdfg, state_id, [src_node, dst_node])
+
+    def memlet_definition(self, sdfg, memlet, local_name, direction="in"):
+        if isinstance(memlet.data, data.Stream):
+            return 'auto& %s = %s;\n' % (local_name, memlet.data)
+
+        result = ('auto __%s = ' % local_name + self.memlet_view_ctor(
+            sdfg, memlet, direction) + ';\n')
+
+        # Declare the variable with its vector type
+        memlet_type = ' dace::vec<%s, %s>' % (
+            sdfg.arrays[memlet.data].dtype.ctype, sym2cpp(memlet.veclen))
+        if memlet.subset.data_dims() == 0 and memlet.num_accesses >= 0:
+            result += memlet_type + ' ' + local_name
+            if direction == "in":
+                result += ' = __%s;\n' % local_name
+            else:
+                result += ';\n'
+
+        return result
+
+    def memlet_view_ctor(self, sdfg, memlet, direction):
+        useskip = False
+        memlet_params = []
+
+        memlet_name = memlet.data
+        if isinstance(sdfg.arrays[memlet.data], data.Scalar):
+            raise ValueError("This should never have happened")
+
+        if isinstance(memlet.subset, subsets.Indices):
+            # Compute address
+            memlet_params.append(cpp_array_expr(sdfg, memlet, False))
+            dims = 0
+
+        elif isinstance(memlet.subset, subsets.Range):
+            dims = len(memlet.subset.ranges)
+            #memlet_params.append("")
+
+            # Dimensions to remove from view (due to having one value)
+            indexdims = []
+            nonIndexDims = []
+
+            for dim, (rb, re, rs) in enumerate(memlet.subset.ranges):
+                if rs != 1:
+                    useskip = True
+                try:
+                    if (re - rb) == 0:
+                        indexdims.append(dim)
+                    else:
+                        nonIndexDims.append(dim)
+                except TypeError:  # cannot determine truth value of Relational
+                    nonIndexDims.append(dim)
+
+            if len(nonIndexDims) > 1 and len(indexdims) > 0:
+                raise NotImplementedError(
+                    'subviews of more than one dimension '
+                    'not implemented')
+            elif len(nonIndexDims) == 1 and len(indexdims) > 0:  # One dimension
+                indexdim = nonIndexDims[0]
+
+                # Contiguous dimension
+                if indexdim == dims - 1:
+                    memlet_params[-1] += ' + %s' % cpp_array_expr(
+                        sdfg, memlet, False)
+                    memlet_params.append(
+                        '0, %s' % (sym2cpp(memlet.subset.ranges[-1][1] -
+                                           memlet.subset.ranges[-1][0])))
+                else:  # Non-contiguous dimension
+                    useskip = True
+                    memlet_params[-1] += ' + %s' % cpp_array_expr(
+                        sdfg, memlet, False)
+                    memlet_range = memlet.subset.ranges[indexdim]
+
+                    # TODO(later): Access order
+                    memlet_stride = functools.reduce(
+                        lambda x, y: x * y,
+                        sdfg.arrays[memlet.data].shape[indexdim + 1:])
+                    memlet_stride = sym2cpp(memlet_stride)
+
+                    memlet_params.append(
+                        '0, %s, %s' %
+                        (sym2cpp(memlet_range[1] - memlet_range[0]),
+                         sym2cpp(memlet_stride)))
+
+                # Subtract index dimensions from array dimensions
+                dims -= len(indexdims)
+
+            elif len(indexdims) == 0:
+                for (rb, re, rs), s in zip(memlet.subset.ranges,
+                                           sdfg.arrays[memlet.data].shape):
+                    if useskip:
+                        memlet_params.append(
+                            '%s, %s, %s' %
+                            (cppunparse.pyexpr2cpp(symbolic.symstr(rb)),
+                             cppunparse.pyexpr2cpp(symbolic.symstr(s)),
+                             cppunparse.pyexpr2cpp(symbolic.symstr(rs))))
+                    else:
+                        memlet_params.append(
+                            '%s, %s' %
+                            (cppunparse.pyexpr2cpp(symbolic.symstr(rb)),
+                             cppunparse.pyexpr2cpp(symbolic.symstr(s))))
+            elif len(nonIndexDims) == 0:  # Scalar view
+                # Compute address
+                memlet_params[-1] += ' + ' + cpp_array_expr(
+                    sdfg, memlet, False)
+                dims = 0
+
+        else:
+            raise RuntimeError(
+                'Memlet type "%s" not implemented' % memlet.subset)
+
+        if dims == 0:
+            return 'dace::ArrayViewImmaterial%s%s<%s, %s, int32_t> ("%s", %s)' % (
+                'In' if direction == "in" else "Out", 'Skip'
+                if useskip else '', sdfg.arrays[memlet.data].dtype.ctype,
+                symbolic.symstr(
+                    memlet.veclen), memlet.data, ', '.join(memlet_params))
+        else:
+            return 'dace::ArrayViewImmaterial%s%s<%s, %s, int32_t, %s> ("%s", %s)' % (
+                'In' if direction == "in" else "Out", 'Skip'
+                if useskip else '', sdfg.arrays[memlet.data].dtype.ctype,
+                symbolic.symstr(memlet.veclen), ', '.join([
+                    str(s) for s in memlet.subset.bounding_box_size()
+                ]), memlet.data, ', '.join(memlet_params))
diff --git a/dace/codegen/targets/mpi.py b/dace/codegen/targets/mpi.py
new file mode 100644
index 0000000000..72356a7463
--- /dev/null
+++ b/dace/codegen/targets/mpi.py
@@ -0,0 +1,129 @@
+import dace
+from dace import symbolic, types
+from dace.codegen.prettycode import CodeIOStream
+from dace.codegen.codeobject import CodeObject
+from dace.codegen.targets.target import TargetCodeGenerator, make_absolute
+from dace.graph import nodes
+from dace.config import Config
+
+from dace.codegen import cppunparse
+
+
+class MPICodeGen(TargetCodeGenerator):
+    """ An MPI code generator. """
+    target_name = 'mpi'
+    title = 'MPI'
+    language = 'cpp'
+
+    def __init__(self, frame_codegen, sdfg):
+        self._frame = frame_codegen
+        self._dispatcher = frame_codegen.dispatcher
+        dispatcher = self._dispatcher
+
+        fileheader = CodeIOStream()
+        self._frame.generate_fileheader(sdfg, fileheader)
+
+        self._codeobj = CodeObject(
+            sdfg.name + '_mpi', """
+#include <dace/dace.h>
+#include <mpi.h>
+
+MPI_Comm __dace_mpi_comm;
+int __dace_comm_size = 1;
+int __dace_comm_rank = 0;
+
+{file_header}
+
+DACE_EXPORTED int __dace_init_mpi({params});
+DACE_EXPORTED void __dace_exit_mpi({params});
+
+int __dace_init_mpi({params}) {{
+    if (MPI_Init(NULL, NULL) != MPI_SUCCESS)
+        return 1;
+
+    MPI_Comm_dup(MPI_COMM_WORLD, &__dace_mpi_comm);
+    MPI_Comm_rank(__dace_mpi_comm, &__dace_comm_rank);
+    MPI_Comm_size(__dace_mpi_comm, &__dace_comm_size);
+
+    printf(\"MPI was initialized on proc %i of %i\\n\", __dace_comm_rank,
+           __dace_comm_size);
+    return 0;
+}}
+
+void __dace_exit_mpi({params}) {{
+    MPI_Comm_free(&__dace_mpi_comm);
+    MPI_Finalize();
+
+    printf(\"MPI was finalized on proc %i of %i\\n\", __dace_comm_rank,
+           __dace_comm_size);
+}}
+""".format(params=sdfg.signature(), file_header=fileheader.getvalue()), 'cpp',
+            MPICodeGen, 'MPI')
+
+        # Register dispatchers
+        dispatcher.register_map_dispatcher(types.ScheduleType.MPI, self)
+
+    def get_generated_codeobjects(self):
+        return [self._codeobj]
+
+    @staticmethod
+    def cmake_options():
+        compiler = make_absolute(Config.get("compiler", "mpi", "executable"))
+        return [
+            "-DMPI_CXX_COMPILER=\"{}\"".format(compiler),
+            "-DDACE_ENABLE_MPI=ON",
+        ]
+
+    @property
+    def has_initializer(self):
+        return True
+
+    @property
+    def has_finalizer(self):
+        return True
+
+    def generate_scope(self, sdfg, dfg_scope, state_id, function_stream,
+                       callsite_stream):
+        # Take care of map header
+        assert len(dfg_scope.source_nodes()) == 1
+        map_header = dfg_scope.source_nodes()[0]
+
+        function_stream.write('extern int __dace_comm_size, __dace_comm_rank;',
+                              sdfg, state_id, map_header)
+
+        if
len(map_header.map.params) > 1: + raise NotImplementedError( + 'Multi-dimensional MPI maps are not supported') + + for var, r in zip(map_header.map.params, map_header.map.range): + begin, end, skip = r + + callsite_stream.write('{\n', sdfg, state_id, map_header) + callsite_stream.write( + 'auto %s = %s + __dace_comm_rank * (%s);\n' % + (var, cppunparse.pyexpr2cpp(symbolic.symstr(begin)), + cppunparse.pyexpr2cpp(symbolic.symstr(skip))), sdfg, state_id, + map_header) + + to_allocate = dace.sdfg.local_transients(sdfg, dfg_scope, map_header) + allocated = set() + for child in dfg_scope.scope_dict(node_to_children=True)[map_header]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in allocated: + continue + allocated.add(child.data) + self._dispatcher.dispatch_allocate(sdfg, dfg_scope, state_id, + child, function_stream, + callsite_stream) + self._dispatcher.dispatch_initialize(sdfg, dfg_scope, state_id, + child, function_stream, + callsite_stream) + + self._dispatcher.dispatch_subgraph( + sdfg, + dfg_scope, + state_id, + function_stream, + callsite_stream, + skip_entry_node=True) diff --git a/dace/codegen/targets/target.py b/dace/codegen/targets/target.py new file mode 100644 index 0000000000..73e10ba6cc --- /dev/null +++ b/dace/codegen/targets/target.py @@ -0,0 +1,570 @@ +import os +import shutil # which +import dace +from dace import types +from dace.graph import nodes, nxutil + + +class TargetCodeGenerator(object): + """ Interface dictating functions that generate code for: + * Array allocation/deallocation/initialization/copying + * Scope (map, consume) code generation + """ + + def get_generated_codeobjects(self): + """ Returns a list of generated `CodeObject` classes corresponding + to files with generated code. + @see: CodeObject + """ + raise NotImplementedError('Abstract class') + + @property + def has_initializer(self): + """ Returns True if the target generates a `__dace_init_` + function that should be called on initialization. """ + raise NotImplementedError('Abstract class') + + @property + def has_finalizer(self): + """ Returns True if the target generates a `__dace_exit_` + function that should be called on finalization. """ + raise NotImplementedError('Abstract class') + + def generate_state(self, sdfg, state, function_stream, callsite_stream): + """ Generates code for an SDFG state, outputting it to the given + code streams. + @param sdfg: The SDFG to generate code from. + @param state: The SDFGState to generate code from. + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. + """ + raise NotImplementedError('Abstract class') + + def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, + callsite_stream): + """ Generates code for an SDFG state scope (from a scope-entry node + to its corresponding scope-exit node), outputting it to the given + code streams. + @param sdfg: The SDFG to generate code from. + @param dfg_scope: The `ScopeSubgraphView` to generate code from. + @param state_id: The node ID of the state in the given SDFG. + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. 
+ """ + raise NotImplementedError('Abstract class') + + def generate_node(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Generates code for a single node, outputting it to the given + code streams. + @param sdfg: The SDFG to generate code from. + @param dfg: The SDFG state to generate code from. + @param state_id: The node ID of the state in the given SDFG. + @param node: The node to generate code from. + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. + """ + raise NotImplementedError('Abstract class') + + def allocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Generates code for allocating an array, outputting to the given + code streams. + @param sdfg: The SDFG to generate code from. + @param dfg: The SDFG state to generate code from. + @param state_id: The node ID of the state in the given SDFG. + @param node: The data node to generate allocation for. + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. + """ + raise NotImplementedError('Abstract class') + + def initialize_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Generates code for initializing an array, outputting to the given + code streams. + @param sdfg: The SDFG to generate code from. + @param dfg: The SDFG state to generate code from. + @param state_id: The node ID of the state in the given SDFG. + @param node: The data node to generate initialization for. + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. + """ + raise NotImplementedError('Abstract class') + + def deallocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Generates code for deallocating an array, outputting to the given + code streams. + @param sdfg: The SDFG to generate code from. + @param dfg: The SDFG state to generate code from. + @param state_id: The node ID of the state in the given SDFG. + @param node: The data node to generate deallocation for. + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. + """ + raise NotImplementedError('Abstract class') + + def copy_memory(self, sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, callsite_stream): + """ Generates code for copying memory, either from a data access + node (array/stream) to another, a code node (tasklet/nested + SDFG) to another, or a combination of the two. + @param sdfg: The SDFG to generate code from. + @param dfg: The SDFG state to generate code from. + @param state_id: The node ID of the state in the given SDFG. + @param src_node: The source node to generate copy code for. + @param dst_node: The destination node to generate copy code for. 
+ @param edge: The edge representing the copy (in the innermost + scope, adjacent to either the source or destination + node). + @param function_stream: A `CodeIOStream` object that will be + generated outside the calling code, for + use when generating global functions. + @param callsite_stream: A `CodeIOStream` object that points + to the current location (call-site) + in the code. + """ + raise NotImplementedError('Abstract class') + + +class IllegalCopy(TargetCodeGenerator): + """ A code generator that is triggered when invalid copies are specified + by the SDFG. Only raises an exception on failure. """ + + def copy_memory(self, sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, callsite_stream): + raise TypeError('Illegal copy! (from ' + str(src_node) + ' to ' + + str(dst_node) + ')') + + +class DefinedType(dace.types.AutoNumber): + """ Data types for `DefinedMemlets`. + @see: DefinedMemlets + """ + Pointer = () + ArrayView = () + Scalar = () + ScalarView = () + Stream = () + StreamArray = () + + +class DefinedMemlets: + """ Keeps track of the type of defined memlets to ensure that they are + referenced correctly in nested scopes and SDFGs. """ + + def __init__(self): + self._scopes = [(None, {})] + + def enter_scope(self, parent): + self._scopes.append((parent, {})) + + def exit_scope(self, parent): + expected, _ = self._scopes.pop() + if expected != parent: + raise ValueError( + "Exited scope {} mismatched current scope {}".format( + parent.name, expected.name)) + + def get(self, name): + for _, scope in reversed(self._scopes): + if name in scope: + return scope[name] + raise KeyError("Variable {} has not been defined".format(name)) + + def add(self, name, connector_type): + if not isinstance(name, str): + raise TypeError( + 'Variable name type cannot be %s' % type(name).__name__) + + for _, scope in reversed(self._scopes): + if name in scope: + err_str = "Shadowing variable {} from type {} to {}".format( + name, scope[name], connector_type) + if dace.config.Config.get_bool("compiler", "allow_shadowing"): + print("WARNING: " + err_str) + else: + raise dace.codegen.codegen.CodegenError(err_str) + self._scopes[-1][1][name] = connector_type + + +############################################################################# + + +class TargetDispatcher(object): + """ Dispatches sub-SDFG generation (according to scope), + storage<->storage copies, and storage<->tasklet copies to targets. """ + + def __init__(self): + self._used_targets = set() + + self._array_dispatchers = { + } # Type: types.StorageType -> TargetCodeGenerator + self._map_dispatchers = { + } # Type: types.ScheduleType -> TargetCodeGenerator + self._copy_dispatchers = {} # Type: (types.StorageType src, + # types.StorageType dst, + # types.ScheduleType dst_schedule) + # -> TargetCodeGenerator + self._node_dispatchers = [] # [(predicate, dispatcher)] + self._generic_node_dispatcher = None # Type: TargetCodeGenerator + self._state_dispatchers = [] # [(predicate, dispatcher)] + self._generic_state_dispatcher = None # Type: TargetCodeGenerator + + self._defined_vars = DefinedMemlets() + + @property + def defined_vars(self): + """ Returns a list of defined variables. + @rtype: DefinedMemlets + """ + return self._defined_vars + + @property + def used_targets(self): + """ Returns a list of targets (code generators) that were triggered + during generation. 
""" + return self._used_targets + + def register_state_dispatcher(self, dispatcher, predicate=None): + """ Registers a code generator that processes a single state, calling + `generate_state`. + @param dispatcher: The code generator to use. + @param predicate: A lambda function that accepts the SDFG and + state, and triggers the code generator when True + is returned. If None, registers `dispatcher` + as the default state dispatcher. + @see: TargetCodeGenerator + """ + + if not hasattr(dispatcher, "generate_state"): + raise TypeError("State dispatcher \"{}\" does not " + "implement \"generate_state\"".format(dispatcher)) + if predicate is None: + self._generic_state_dispatcher = dispatcher + else: + self._state_dispatchers.append((predicate, dispatcher)) + + def get_generic_state_dispatcher(self): + """ Returns the default state dispatcher. """ + return self._generic_state_dispatcher + + def get_predicated_state_dispatchers(self): + """ Returns a list of state dispatchers with predicates. """ + return list(self._state_dispatchers) + + def register_node_dispatcher(self, dispatcher, predicate=None): + """ Registers a code generator that processes a single node, calling + `generate_node`. + @param dispatcher: The code generator to use. + @param predicate: A lambda function that accepts the SDFG, state, + and node, and triggers the code generator when + True is returned. If None, registers `dispatcher` + as the default node dispatcher. + @see: TargetCodeGenerator + """ + if not hasattr(dispatcher, "generate_node"): + raise TypeError("Node dispatcher must " + "implement \"generate_node\"") + if predicate is None: + self._generic_node_dispatcher = dispatcher + else: + self._node_dispatchers.append((predicate, dispatcher)) + + def get_generic_node_dispatcher(self): + """ Returns the default node dispatcher. """ + return self._generic_node_dispatcher + + def get_predicated_node_dispatchers(self): + """ Returns a list of node dispatchers with predicates. """ + return list(self._node_dispatchers) + + def register_map_dispatcher(self, schedule_type, func): + """ Registers a function that processes a scope, used when calling + `dispatch_subgraph` and `dispatch_scope`. + @param schedule_type: The scope schedule that triggers `func`. + @param func: A TargetCodeGenerator object that contains an + implementation of `generate_scope`. + @see: TargetCodeGenerator + """ + if isinstance(schedule_type, list): + for stype in schedule_type: + self.register_map_dispatcher(stype, func) + return + + if not isinstance(schedule_type, types.ScheduleType): raise TypeError + if not isinstance(func, TargetCodeGenerator): raise TypeError + if schedule_type in self._map_dispatchers: + raise ValueError('Schedule already mapped to ' + + str(self._map_dispatchers[schedule_type])) + self._map_dispatchers[schedule_type] = func + + def register_array_dispatcher(self, storage_type, func): + """ Registers a function that processes data allocation, + initialization, and deinitialization. Used when calling + `dispatch_allocate/deallocate/initialize`. + @param storage_type: The data storage type that triggers `func`. + @param func: A TargetCodeGenerator object that contains an + implementation of data memory management functions. 
+ @see: TargetCodeGenerator + """ + if isinstance(storage_type, list): + for stype in storage_type: + self.register_array_dispatcher(stype, func) + return + + if not isinstance(storage_type, types.StorageType): raise TypeError + if not isinstance(func, TargetCodeGenerator): raise TypeError + self._array_dispatchers[storage_type] = func + + def register_copy_dispatcher(self, src_storage, dst_storage, dst_schedule, + func): + """ Registers code generation of data-to-data (or data from/to + tasklet, if src/dst storage is StorageType.Register) copy + functions. Can also be target-schedule specific, or + dst_schedule=None if the function will be invoked on any schedule. + @param src_storage: The source data storage type that triggers + `func`. + @param dst_storage: The destination data storage type that + triggers `func`. + @param dst_schedule: An optional destination scope schedule type + that triggers `func`. + @param func: A TargetCodeGenerator object that contains an + implementation of `copy_memory`. + @see: TargetCodeGenerator + """ + + if not isinstance(src_storage, types.StorageType): raise TypeError + if not isinstance(dst_storage, types.StorageType): raise TypeError + if (dst_schedule is not None + and not isinstance(dst_schedule, types.ScheduleType)): + raise TypeError + if not isinstance(func, TargetCodeGenerator): raise TypeError + + self._copy_dispatchers[(src_storage, dst_storage, dst_schedule)] = \ + func + + def dispatch_state(self, sdfg, state, function_stream, callsite_stream): + """ Dispatches a code generator for an SDFG state. """ + + self.defined_vars.enter_scope(state) + # Check if the state satisfies any predicates that delegate to a + # specific code generator + satisfied_dispatchers = [ + dispatcher for pred, dispatcher in self._state_dispatchers + if pred(sdfg, state) is True + ] + num_satisfied = len(satisfied_dispatchers) + if num_satisfied > 1: + raise RuntimeError( + "Multiple predicates satisfied for {}: {}".format( + state, ", ".join( + [type(x).__name__ for x in satisfied_dispatchers]))) + elif num_satisfied == 1: + satisfied_dispatchers[0].generate_state( + sdfg, state, function_stream, callsite_stream) + else: # num_satisfied == 0 + # Otherwise use the generic code generator (CPU) + self._generic_state_dispatcher.generate_state( + sdfg, state, function_stream, callsite_stream) + self.defined_vars.exit_scope(state) + + def dispatch_subgraph(self, + sdfg, + dfg, + state_id, + function_stream, + callsite_stream, + skip_entry_node=False): + """ Dispatches a code generator for a scope subgraph of an + `SDFGState`. 
""" + + start_nodes = list( + v for v in dfg.nodes() if len(list(dfg.predecessors(v))) == 0) + + # Mark nodes to skip in order to be able to skip + nodes_to_skip = set() + + if skip_entry_node: + assert len(start_nodes) == 1 + nodes_to_skip.add(start_nodes[0]) + + for v in nxutil.dfs_topological_sort(dfg, start_nodes): + if v in nodes_to_skip: + continue + + if isinstance(v, nodes.MapEntry): + scope_subgraph = sdfg.find_state(state_id).scope_subgraph(v) + + # Propagate parallelism + if dfg.is_parallel(): + scope_subgraph.set_parallel_parent(dfg.get_parallel_parent) + + assert not dfg.is_parallel() or scope_subgraph.is_parallel() + self.dispatch_scope(v.map.schedule, sdfg, scope_subgraph, + state_id, function_stream, callsite_stream) + + # Skip scope subgraph nodes + #print(scope_subgraph.nodes()) + nodes_to_skip.update(scope_subgraph.nodes()) + else: + self.dispatch_node(sdfg, dfg, state_id, v, function_stream, + callsite_stream) + + def dispatch_node(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Dispatches a code generator for a single node. """ + + # Check if the node satisfies any predicates that delegate to a + # specific code generator + satisfied_dispatchers = [ + dispatcher for pred, dispatcher in self._node_dispatchers + if pred(sdfg, node) + ] + num_satisfied = len(satisfied_dispatchers) + if num_satisfied > 1: + raise RuntimeError( + "Multiple predicates satisfied for {}: {}".format( + node, ", ".join( + [type(x).__name__ for x in satisfied_dispatchers]))) + elif num_satisfied == 1: + self._used_targets.add(satisfied_dispatchers[0]) + satisfied_dispatchers[0].generate_node( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + else: # num_satisfied == 0 + # Otherwise use the generic code generator (CPU) + self._used_targets.add(self._generic_node_dispatcher) + self._generic_node_dispatcher.generate_node( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + + def dispatch_scope(self, map_schedule, sdfg, sub_dfg, state_id, + function_stream, callsite_stream): + """ Dispatches a code generator function for a scope in an SDFG + state. """ + entry_node = sub_dfg.source_nodes()[0] + self.defined_vars.enter_scope(entry_node) + self._used_targets.add(self._map_dispatchers[map_schedule]) + self._map_dispatchers[map_schedule].generate_scope( + sdfg, sub_dfg, state_id, function_stream, callsite_stream) + self.defined_vars.exit_scope(entry_node) + + def dispatch_allocate(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Dispatches a code generator for data allocation. """ + + nodedesc = node.desc(sdfg) + storage = (nodedesc.storage if not isinstance(node, nodes.Tasklet) else + types.StorageType.Register) + self._used_targets.add(self._array_dispatchers[storage]) + + self._array_dispatchers[storage].allocate_array( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + + def dispatch_initialize(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Dispatches a code generator for a data initialization. """ + + nodedesc = node.desc(sdfg) + storage = (nodedesc.storage if not isinstance(node, nodes.Tasklet) else + types.StorageType.Register) + self._used_targets.add(self._array_dispatchers[storage]) + self._array_dispatchers[storage].initialize_array( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + + def dispatch_deallocate(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + """ Dispatches a code generator for a data deallocation. 
""" + + nodedesc = node.desc(sdfg) + storage = (nodedesc.storage if not isinstance(node, nodes.Tasklet) else + types.StorageType.Register) + self._used_targets.add(self._array_dispatchers[storage]) + + self._array_dispatchers[storage].deallocate_array( + sdfg, dfg, state_id, node, function_stream, callsite_stream) + + # Dispatches copy code for a memlet + def dispatch_copy(self, src_node, dst_node, edge, sdfg, dfg, state_id, + function_stream, output_stream): + """ Dispatches a code generator for a memory copy operation. """ + + if isinstance(src_node, nodes.CodeNode): + src_storage = types.StorageType.Register + else: + src_storage = src_node.desc(sdfg).storage + + if isinstance(dst_node, nodes.CodeNode): + dst_storage = types.StorageType.Register + else: + dst_storage = dst_node.desc(sdfg).storage + + if (isinstance(src_node, nodes.Tasklet) + and not isinstance(dst_node, nodes.Tasklet)): + # Special case: Copying from a tasklet to an array, schedule of + # the copy is in the copying tasklet + dst_schedule_node = dfg.scope_dict()[src_node] + else: + dst_schedule_node = dfg.scope_dict()[dst_node] + + if dst_schedule_node is not None: + dst_schedule = dst_schedule_node.map.schedule + else: + dst_schedule = None + + if (src_storage, dst_storage, dst_schedule) in self._copy_dispatchers: + target = self._copy_dispatchers[(src_storage, dst_storage, + dst_schedule)] + self._used_targets.add(target) + target.copy_memory(sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, output_stream) + elif (src_storage, dst_storage, None) in self._copy_dispatchers: + target = self._copy_dispatchers[(src_storage, dst_storage, None)] + self._used_targets.add(target) + target.copy_memory(sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, output_stream) + else: + raise RuntimeError('Copy dispatcher for %s->%s with schedule %s' % + (str(src_storage), str(dst_storage), + str(dst_schedule)) + ' not found') + + +def make_absolute(path): + if os.path.isfile(path): + if os.path.isabs(path): + # Path is abolute, we're happy + return path + else: + # Path is relative: make it absolute + return os.path.abspath(path) + else: + # This is not a path, probably just an executable name, such + # as "g++". 
Try to find it on the PATH + executable = shutil.which(path) + if not executable: + raise ValueError("Could not find executable \"{}\"".format(path)) + return executable diff --git a/dace/codegen/targets/xilinx.py b/dace/codegen/targets/xilinx.py new file mode 100644 index 0000000000..a836d07387 --- /dev/null +++ b/dace/codegen/targets/xilinx.py @@ -0,0 +1,1683 @@ +from six import StringIO +import collections +import functools +import os +import itertools +import re +import sympy as sp + +import dace +from dace import subsets +from dace.config import Config +from dace.frontend import operations +from dace.graph import nodes +from dace.sdfg import ScopeSubgraphView, find_input_arraynode, find_output_arraynode +from dace.codegen.codeobject import CodeObject +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.targets.target import (TargetCodeGenerator, IllegalCopy, + make_absolute, DefinedType) +from dace.codegen.targets.cpu import cpp_offset_expr, cpp_array_expr +from dace.codegen.targets import cpu, cuda + +from dace.codegen import cppunparse + +REDUCTION_TYPE_TO_HLSLIB = { + dace.types.ReductionType.Min: "hlslib::op::Min", + dace.types.ReductionType.Max: "hlslib::op::Max", + dace.types.ReductionType.Sum: "hlslib::op::Sum", + dace.types.ReductionType.Product: "hlslib::op::Product", + dace.types.ReductionType.Logical_And: "hlslib::op::And", +} + + +class XilinxCodeGen(TargetCodeGenerator): + """ Xilinx FPGA code generator. """ + target_name = 'xilinx' + title = 'Xilinx' + language = 'hls' + + def __init__(self, frame_codegen, sdfg): + self._in_device_code = False + self._cpu_codegen = None + self._frame = frame_codegen + self._dispatcher = frame_codegen.dispatcher + + self._global_sdfg = sdfg + self._program_name = sdfg.name + + # Verify that we did not miss the allocation of any global arrays, even + # if they're nested deep in the SDFG + self._allocated_global_arrays = set() + self._unrolled_pes = set() + + # Register dispatchers + self._cpu_codegen = self._dispatcher.get_generic_node_dispatcher() + + self._host_codes = [] + self._kernel_codes = [] + + # Register additional Xilinx dispatchers + self._dispatcher.register_map_dispatcher( + [dace.types.ScheduleType.FPGA_Device], self) + + self._dispatcher.register_state_dispatcher( + self, + predicate=lambda sdfg, state: len(state.data_nodes()) > 0 and all([ + n.desc(sdfg).storage in [ + dace.types.StorageType.FPGA_Global, + dace.types.StorageType.FPGA_Local, + dace.types.StorageType.FPGA_Registers] + for n in state.data_nodes()])) + + self._dispatcher.register_node_dispatcher( + self, predicate=lambda *_: self._in_device_code) + + xilinx_storage = [ + dace.types.StorageType.FPGA_Global, + dace.types.StorageType.FPGA_Local, + dace.types.StorageType.FPGA_Registers, + ] + self._dispatcher.register_array_dispatcher(xilinx_storage, self) + + # Register permitted copies + for storage_from in itertools.chain(xilinx_storage, + [dace.types.StorageType.Register]): + for storage_to in itertools.chain( + xilinx_storage, [dace.types.StorageType.Register]): + if (storage_from == dace.types.StorageType.Register + and storage_to == dace.types.StorageType.Register): + continue + self._dispatcher.register_copy_dispatcher( + storage_from, storage_to, None, self) + self._dispatcher.register_copy_dispatcher( + dace.types.StorageType.FPGA_Global, + dace.types.StorageType.CPU_Heap, None, self) + self._dispatcher.register_copy_dispatcher( + dace.types.StorageType.FPGA_Global, + dace.types.StorageType.CPU_Stack, None, self) + 
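# ---------------------------------------------------------------------------
# Aside: the registrations above and below enumerate every legal pairing of
# FPGA and host storage types. A minimal sketch of the same cross product in
# isolation, assuming the `dace` package from this patch is importable.
# `MockDispatcher` is hypothetical and only mimics the keying scheme of
# TargetDispatcher._copy_dispatchers from target.py above:
import itertools
from dace import types

fpga = [types.StorageType.FPGA_Global, types.StorageType.FPGA_Local,
        types.StorageType.FPGA_Registers]
host = [types.StorageType.CPU_Heap, types.StorageType.CPU_Stack]

class MockDispatcher:
    def __init__(self):
        self.copy_dispatchers = {}

    def register_copy_dispatcher(self, src, dst, dst_schedule, func):
        # Keyed by (src storage, dst storage, dst schedule), like target.py
        self.copy_dispatchers[(src, dst, dst_schedule)] = func

d = MockDispatcher()
for src, dst in itertools.chain(itertools.product(fpga, host),
                                itertools.product(host, fpga)):
    d.register_copy_dispatcher(src, dst, None, "xilinx_codegen")
print(len(d.copy_dispatchers))  # 12 host<->FPGA pairings
# ---------------------------------------------------------------------------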
+        self._dispatcher.register_copy_dispatcher(
+            dace.types.StorageType.CPU_Heap,
+            dace.types.StorageType.FPGA_Global, None, self)
+        self._dispatcher.register_copy_dispatcher(
+            dace.types.StorageType.CPU_Stack,
+            dace.types.StorageType.FPGA_Global, None, self)
+
+    @property
+    def has_initializer(self):
+        return True
+
+    @property
+    def has_finalizer(self):
+        return False
+
+    @staticmethod
+    def cmake_options():
+        compiler = make_absolute(
+            Config.get("compiler", "xilinx", "executable"))
+        host_flags = Config.get("compiler", "xilinx", "host_flags")
+        synthesis_flags = Config.get("compiler", "xilinx", "synthesis_flags")
+        build_flags = Config.get("compiler", "xilinx", "build_flags")
+        mode = Config.get("compiler", "xilinx", "mode")
+        target_platform = Config.get("compiler", "xilinx", "platform")
+        enable_debugging = ("ON"
+                            if Config.get_bool("compiler", "xilinx",
+                                               "enable_debugging") else "OFF")
+        options = [
+            "-DSDACCEL_ROOT_DIR={}".format(
+                os.path.dirname(os.path.dirname(compiler))),
+            "-DDACE_XILINX_HOST_FLAGS=\"{}\"".format(host_flags),
+            "-DDACE_XILINX_SYNTHESIS_FLAGS=\"{}\"".format(synthesis_flags),
+            "-DDACE_XILINX_BUILD_FLAGS=\"{}\"".format(build_flags),
+            "-DDACE_XILINX_MODE={}".format(mode),
+            "-DDACE_XILINX_TARGET_PLATFORM=\"{}\"".format(target_platform),
+            "-DDACE_XILINX_ENABLE_DEBUGGING={}".format(enable_debugging),
+        ]
+        return options
+
+    def generate_state(self, sdfg, state, function_stream, callsite_stream):
+        """ Generate a kernel that runs all connected components within a state
+            as concurrent dataflow modules. """
+
+        state_id = sdfg.node_id(state)
+
+        # Determine independent components
+        subgraphs = dace.sdfg.concurrent_subgraphs(state)
+
+        # Generate kernel code
+        shared_transients = set(sdfg.shared_transients())
+        if not self._in_device_code:
+            # Allocate global memory transients, unless they are shared with
+            # other states
+            all_transients = set(state.all_transients())
+            allocated = set(shared_transients)
+            for node in state.data_nodes():
+                data = node.desc(sdfg)
+                if node.data not in all_transients or node.data in allocated:
+                    continue
+                if data.storage != dace.types.StorageType.FPGA_Global:
+                    continue
+                allocated.add(node.data)
+                self._dispatcher.dispatch_allocate(sdfg, state, state_id, node,
+                                                   function_stream,
+                                                   callsite_stream)
+                self._dispatcher.dispatch_initialize(sdfg, state, state_id,
+                                                     node, function_stream,
+                                                     callsite_stream)
+            # Generate kernel code
+            self.generate_kernel(sdfg, state, state.label, subgraphs,
+                                 function_stream, callsite_stream)
+        else:  # self._in_device_code == True
+            to_allocate = dace.sdfg.local_transients(sdfg, state, None)
+            allocated = set()
+            for node in state.data_nodes():
+                data = node.desc(sdfg)
+                if node.data not in to_allocate or node.data in allocated:
+                    continue
+                # Make sure there are no global transients in the nested state
+                # that would therefore not be allocated
+                if data.storage == dace.types.StorageType.FPGA_Global:
+                    raise dace.codegen.codegen.CodegenError(
+                        "Cannot allocate global memory from device code.")
+                allocated.add(node.data)
+                # Allocate transients
+                self._dispatcher.dispatch_allocate(sdfg, state, state_id, node,
+                                                   function_stream,
+                                                   callsite_stream)
+                self._dispatcher.dispatch_initialize(sdfg, state, state_id,
+                                                     node, function_stream,
+                                                     callsite_stream)
+            self.generate_nested_state(sdfg, state, state.label, subgraphs,
+                                       function_stream, callsite_stream)
+
+    @staticmethod
+    def shared_data(subgraphs):
+        """ Returns a set of data objects that are shared between two or more
+            of the specified subgraphs.
""" + shared = set() + if len(subgraphs) >= 2: + seen = {} + for sg in subgraphs: + for node in sg: + if isinstance(node, dace.graph.nodes.AccessNode): + if node.data in seen: + if seen[node.data] != sg: + shared.add(node.data) + else: + seen[node.data] = sg + return shared + + @staticmethod + def global_transient_nodes(subgraphs): + """ Generator that returns all transient global arrays nested in the + passed subgraphs on the form (is_output, AccessNode). """ + seen = set() + for subgraph in subgraphs: + for n, scope in subgraph.all_nodes_recursive(): + if (isinstance(n, dace.graph.nodes.AccessNode) + and n.desc(sdfg).transient and n.desc(sdfg).storage == + dace.types.StorageType.FPGA_Global): + if n.data in seen: + continue + seen.add(n.data) + if scope.out_degree(n) > 0: + yield (False, n) + if scope.in_degree(n) > 0: + yield (True, n) + + @staticmethod + def make_parameters(sdfg, state, subgraphs): + """ Determines the parameters that must be passed to the passed list of + subgraphs, as well as to the global kernel. """ + + # Get a set of data nodes that are shared across subgraphs + shared_data = XilinxCodeGen.shared_data(subgraphs) + + # For some reason the array allocation dispatcher takes nodes, not + # arrays. Build a dictionary of arrays to arbitrary data nodes + # referring to them. + data_to_node = {} + + global_data_params = [] + top_level_local_data = [] + subgraph_params = collections.OrderedDict() # {subgraph: [params]} + nested_global_transients = [] + nested_global_transients_seen = set() + for subgraph in subgraphs: + data_to_node.update({ + node.data: node + for node in subgraph.nodes() + if isinstance(node, dace.graph.nodes.AccessNode) + }) + subsdfg = subgraph.parent + candidates = [] # type: List[Tuple[bool,str,Data]] + # [(is an output, dataname string, data object)] + for n in subgraph.source_nodes(): + candidates += [(False, e.data.data, + subsdfg.arrays[e.data.data]) + for e in state.in_edges(n)] + for n in subgraph.sink_nodes(): + candidates += [(True, e.data.data, subsdfg.arrays[e.data.data]) + for e in state.out_edges(n)] + # Find other data nodes that are used internally + for n, scope in subgraph.all_nodes_recursive(): + if isinstance(n, dace.graph.nodes.AccessNode): + # Add nodes if they are outer-level, or an inner-level + # transient (inner-level inputs/outputs are just connected + # to data in the outer layers, whereas transients can be + # independent). + if scope == subgraph or n.desc(scope).transient: + if scope.out_degree(n) > 0: + candidates.append((False, n.data, n.desc(scope))) + if scope.in_degree(n) > 0: + candidates.append((True, n.data, n.desc(scope))) + if scope != subgraph: + if (isinstance(n.desc(scope), dace.data.Array) + and n.desc(scope).storage == + dace.types.StorageType.FPGA_Global and + n.data not in nested_global_transients_seen + ): + nested_global_transients.append(n) + nested_global_transients_seen.add(n.data) + subgraph_params[subgraph] = [] + # Differentiate global and local arrays. The former are allocated + # from the host and passed to the device code, while the latter are + # (statically) allocated on the device side. 
+ for is_output, dataname, data in candidates: + if (isinstance(data, dace.data.Array) + or isinstance(data, dace.data.Scalar) + or isinstance(data, dace.data.Stream)): + if data.storage == dace.types.StorageType.FPGA_Global: + subgraph_params[subgraph].append((is_output, dataname, + data)) + if is_output: + global_data_params.append((is_output, dataname, + data)) + else: + global_data_params.append((is_output, dataname, + data)) + elif (data.storage == dace.types.StorageType.FPGA_Local or + data.storage == dace.types.StorageType.FPGA_Registers + ): + if dataname in shared_data: + # Only transients shared across multiple components + # need to be allocated outside and passed as + # parameters + subgraph_params[subgraph].append((is_output, + dataname, data)) + # Resolve the data to some corresponding node to be + # passed to the allocator + top_level_local_data.append(dataname) + else: + raise ValueError("Unsupported storage type: {}".format( + data.storage)) + else: + raise TypeError("Unsupported data type: {}".format( + type(data).__name__)) + subgraph_params[subgraph] = dace.types.deduplicate( + subgraph_params[subgraph]) + + # Deduplicate + global_data_params = dace.types.deduplicate(global_data_params) + top_level_local_data = dace.types.deduplicate(top_level_local_data) + top_level_local_data = [data_to_node[n] for n in top_level_local_data] + + # Get scalar parameters + scalar_parameters = sdfg.scalar_parameters(False) + symbol_parameters = sdfg.undefined_symbols(False) + + return (global_data_params, top_level_local_data, subgraph_params, + scalar_parameters, symbol_parameters, nested_global_transients) + + def generate_nested_state(self, sdfg, state, nest_name, subgraphs, + function_stream, callsite_stream): + + for sg in subgraphs: + + self._dispatcher.dispatch_subgraph( + sdfg, + sg, + sdfg.node_id(state), + function_stream, + callsite_stream, + skip_entry_node=False) + + @staticmethod + def detect_memory_widths(subgraphs): + stack = [] + for sg in subgraphs: + stack += [(n, sg) for n in sg.nodes()] + memory_widths = {} + seen = set() + while len(stack) > 0: + node, graph = stack.pop() + if isinstance(node, dace.graph.nodes.NestedSDFG): + for state in node.sdfg.states(): + stack += [(n, state) for n in state.nodes()] + elif isinstance(node, dace.graph.nodes.AccessNode): + if node in seen: + continue + seen.add(node) + nodedesc = node.desc(graph) + for edge in graph.all_edges(node): + if (isinstance(edge.data, dace.memlet.EmptyMemlet) + or edge.data.data is None): + continue + if node.data not in memory_widths: + if (isinstance(nodedesc, dace.data.Stream) + and nodedesc.veclen != edge.data.veclen): + raise ValueError( + "Vector length on memlet {} ({}) doesn't " + "match vector length of {} ({})".format( + edge.data, edge.data.veclen, node.data, + nodedesc.veclen)) + memory_widths[node.data] = edge.data.veclen + else: + if memory_widths[node.data] != edge.data.veclen: + raise dace.codegen.codegen.CodegenError( + "Inconsistent vector length " + "on FPGA for \"{}\": got {}, had {}".format( + node.data, edge.data.veclen, + memory_widths[node.data])) + return memory_widths + + def generate_kernel(self, sdfg, state, kernel_name, subgraphs, + function_stream, callsite_stream): + + state_id = sdfg.node_id(state) + + (global_data_params, top_level_local_data, subgraph_params, + scalar_parameters, symbol_parameters, + nested_global_transients) = type(self).make_parameters( + sdfg, state, subgraphs) + + # Scalar parameters are never output + sc_parameters = [(False, pname, param) + for 
+                         pname, param in scalar_parameters]
+
+        symbol_params = [
+            v.signature(with_types=True, name=k)
+            for k, v in symbol_parameters.items()
+        ]
+
+        # Inspect the vector length of all memlets leading to each memory, to
+        # make sure that they're consistent, and to allow us to instantiate the
+        # memories as vector types to enable HLS to generate wider data paths.
+        # Since we cannot pass this auxiliary data structure to the allocator,
+        # which is called by the dispatcher, we temporarily store it in the
+        # codegen object.
+        self._memory_widths = XilinxCodeGen.detect_memory_widths(subgraphs)
+
+        # Write host code
+        self.generate_host_code(sdfg, state, kernel_name,
+                                global_data_params + sc_parameters,
+                                symbol_parameters, nested_global_transients,
+                                function_stream, callsite_stream)
+        if self._in_device_code:
+            raise dace.codegen.codegen.CodegenError("Tried to generate kernel from device code")
+        self._in_device_code = True
+        self._cpu_codegen._packed_types = True
+
+        # Now we write the device code
+        module_stream = CodeIOStream()
+        kernel_stream = CodeIOStream()
+
+        # Write header
+        module_stream.write("#include <dace/xilinx/device.h>\n\n", sdfg)
+        self._frame.generate_fileheader(sdfg, module_stream)
+        module_stream.write("\n", sdfg)
+
+        # Build kernel signature
+        kernel_args = []
+        for is_output, dataname, data in global_data_params:
+            if isinstance(data, dace.data.Array):
+                kernel_args.append("dace::vec<{}, {}> *{}_{}".format(
+                    data.dtype.ctype, self._memory_widths[dataname], dataname,
+                    "out" if is_output else "in"))
+            else:
+                kernel_args.append(
+                    data.signature(with_types=True, name=dataname))
+        kernel_args += ([
+            arg.signature(with_types=True, name=argname)
+            for _, argname, arg in scalar_parameters
+        ] + symbol_params)
+
+        # Write kernel signature
+        kernel_stream.write(
+            "DACE_EXPORTED void {}({}) {{\n".format(
+                kernel_name, ', '.join(kernel_args)), sdfg, state_id)
+
+        # Insert interface pragmas
+        mapped_args = 0
+        for arg in kernel_args:
+            var_name = re.findall("\w+", arg)[-1]
+            if "*" in arg:
+                kernel_stream.write(
+                    "#pragma HLS INTERFACE m_axi port={} "
+                    "offset=slave bundle=gmem{}".format(var_name, mapped_args),
+                    sdfg, state_id)
+                mapped_args += 1
+
+        for arg in kernel_args + ["return"]:
+            var_name = re.findall("\w+", arg)[-1]
+            kernel_stream.write(
+                "#pragma HLS INTERFACE s_axilite port={} bundle=control".
+ format(var_name)) + + # TODO: add special case if there's only one module for niceness + kernel_stream.write("\n#pragma HLS DATAFLOW") + kernel_stream.write("\nHLSLIB_DATAFLOW_INIT();") + + # Actual kernel code generation + self.generate_modules(sdfg, state, kernel_name, subgraphs, + subgraph_params, sc_parameters, + symbol_parameters, top_level_local_data, + function_stream, module_stream, kernel_stream) + + kernel_stream.write("HLSLIB_DATAFLOW_FINALIZE();\n}\n") + self._in_device_code = False + self._cpu_codegen._packed_types = False + + concatenated_code = ( + module_stream.getvalue() + kernel_stream.getvalue()) + + # Store code strings to be passed to compilation phase + self._kernel_codes.append((kernel_name, concatenated_code)) + + # Delete the field we've used to pass this dictionary to the memory + # allocator + del self._memory_widths + self._allocated_global_arrays = set() + + def generate_modules(self, sdfg, state, kernel_name, subgraphs, params, + scalar_parameters, symbol_parameters, + top_level_local_data, function_stream, module_stream, + kernel_stream): + + # Emit allocations + state_id = sdfg.node_id(state) + for node in top_level_local_data: + self._dispatcher.dispatch_allocate(sdfg, state, state_id, node, + module_stream, kernel_stream) + self._dispatcher.dispatch_initialize(sdfg, state, state_id, node, + module_stream, kernel_stream) + + # Module generation + for subgraph in subgraphs: + # Traverse to find first tasklets reachable in topological order + to_traverse = subgraph.source_nodes() + seen = set() + while len(to_traverse) > 0: + n = to_traverse.pop() + if n in seen: + continue + seen.add(n) + if (not isinstance(n, dace.graph.nodes.Tasklet) + and not isinstance(n, dace.graph.nodes.NestedSDFG)): + for e in subgraph.out_edges(n): + if e.dst not in seen: + to_traverse.append(e.dst) + # Name module according to all reached tasklets (can be just one) + labels = [ + n.label.replace(" ", "_") for n in seen + if isinstance(n, dace.graph.nodes.Tasklet) + or isinstance(n, dace.graph.nodes.NestedSDFG) + ] + if len(labels) == 0: + labels = [ + n.label.replace(" ", "_") for n in seen + if isinstance(n, dace.graph.nodes.AccessNode) + ] + if len(labels) == 0: + raise RuntimeError( + "Expected at least one tasklet or data node") + module_name = "_".join(labels) + self.generate_module(sdfg, state, module_name, subgraph, + params[subgraph] + scalar_parameters, + symbol_parameters, function_stream, + module_stream, kernel_stream) + + def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, + callsite_stream): + + if not self._in_device_code: + # If we're not already generating kernel code we need to set up the + # kernel launch + subgraphs = [dfg_scope] + return self.generate_kernel( + sdfg, sdfg.find_state(state_id), + dfg_scope.source_nodes()[0].map.label.replace(" ", "_"), + subgraphs, function_stream, callsite_stream) + + self.generate_node(sdfg, dfg_scope, state_id, + dfg_scope.source_nodes()[0], function_stream, + callsite_stream) + + self._dispatcher.dispatch_subgraph( + sdfg, + dfg_scope, + state_id, + function_stream, + callsite_stream, + skip_entry_node=True) + + def generate_host_code(self, sdfg, state, kernel_name, params, + symbol_parameters, nested_global_transients, + function_stream, callsite_stream): + + state_id = sdfg.node_id(state) + + # We exclude nested transients from the CPU code function call, as they + # have not yet been allocated at this point + nested_transient_set = {n.data for n in nested_global_transients} + + symbol_sigs = [ + 
+            v.signature(with_types=True, name=k)
+            for k, v in symbol_parameters.items()
+        ]
+        symbol_names = list(symbol_parameters.keys())
+        seen = set(nested_transient_set)
+        kernel_args_call_wrapper = []
+        kernel_args_call_host = []
+        for is_output, pname, p in params:
+            kernel_args_call_wrapper.append(p.signature(False, name=pname))
+            # Only pass each array once from the host code
+            if pname in seen:
+                continue
+            seen.add(pname)
+            kernel_args_call_host.append(p.signature(False, name=pname))
+        kernel_args_call_wrapper += symbol_names
+        kernel_args_call_host += symbol_names
+        kernel_args_opencl = (XilinxCodeGen.sdaccel_params(
+            sdfg, [p for p in params
+                   if p[1] not in nested_transient_set]) + symbol_sigs)
+        kernel_args_hls = []
+        kernel_args_hls_without_vectorization = []
+        for is_output, argname, arg in params:
+            if isinstance(arg, dace.data.Array):
+                kernel_args_hls.append("dace::vec<{}, {}> *{}_{}".format(
+                    arg.dtype.ctype, self._memory_widths[argname], argname,
+                    "out" if is_output else "in"))
+                kernel_args_hls_without_vectorization.append(
+                    "{} *{}_{}".format(arg.dtype.ctype, argname, "out"
+                                       if is_output else "in"))
+            else:
+                kernel_args_hls.append(
+                    arg.signature(with_types=True, name=argname))
+                kernel_args_hls_without_vectorization.append(
+                    arg.signature(with_types=True, name=argname))
+        kernel_args_hls += symbol_sigs
+        kernel_args_hls_without_vectorization += symbol_sigs
+
+        kernel_function_name = kernel_name
+
+        #----------------------------------------------------------------------
+        # Generate OpenCL host-code
+        #----------------------------------------------------------------------
+
+        kernel_file_name = "{}.xclbin".format(kernel_name)
+        host_function_name = "__dace_runkernel_{}".format(kernel_name)
+
+        # Write OpenCL host function
+        code = CodeIOStream()
+        code.write("""\
+// Signature of kernel function (with raw pointers) for argument matching
+DACE_EXPORTED void {kernel_function_name}({kernel_args_hls_novec});
+
+DACE_EXPORTED void {host_function_name}({kernel_args_opencl}) {{""".format(
+            kernel_function_name=kernel_function_name,
+            kernel_args_hls_novec=", ".join(
+                kernel_args_hls_without_vectorization),
+            host_function_name=host_function_name,
+            kernel_args_opencl=", ".join(kernel_args_opencl)))
+
+        # Any extra transients stored in global memory on the FPGA must now be
+        # allocated and passed to the kernel
+        for arr_node in nested_global_transients:
+            self._dispatcher.dispatch_allocate(sdfg, state, None, arr_node,
+                                               None, code)
+            self._dispatcher.dispatch_initialize(sdfg, state, None, arr_node,
+                                                 None, code)
+
+        code.write("""\
+  hlslib::ocl::Program program =
+      hlslib::ocl::GlobalContext().CurrentlyLoadedProgram();
+  auto kernel = program.MakeKernel({kernel_function_name}, "{kernel_function_name}", {kernel_args});
+  const std::pair<double, double> elapsed = kernel.ExecuteTask();
+  std::cout << "Kernel executed in " << elapsed.second << " seconds.\\n" << std::flush;
+}}""".format(
+            kernel_function_name=kernel_function_name,
+            kernel_args=", ".join(kernel_args_call_wrapper)))
+
+        # Store code to be passed to compilation phase
+        self._host_codes.append((kernel_name, code.getvalue()))
+
+        #----------------------------------------------------------------------
+        # Inject header for OpenCL host code in the calling code file
+        #----------------------------------------------------------------------
+
+        host_declaration = "\n\nDACE_EXPORTED void {}({});\n\n".format(
+            host_function_name, ", ".join(kernel_args_opencl))
+        function_stream.write(host_declaration, sdfg, state_id, None)
+
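# ---------------------------------------------------------------------------
# Aside: what the declaration injected above and the callsite emitted below
# look like once formatted. A self-contained sketch for a hypothetical
# "saxpy" kernel; the argument lists are illustrative, since the real
# signatures come from sdaccel_params and the data descriptors:
host_function_name = "__dace_runkernel_saxpy"
kernel_args_opencl = ["float *x", "float *y", "int N"]
kernel_args_call_host = ["x", "y", "N"]

declaration = "\n\nDACE_EXPORTED void {}({});\n\n".format(
    host_function_name, ", ".join(kernel_args_opencl))
call = "{}({});".format(host_function_name, ", ".join(kernel_args_call_host))

print(declaration)  # DACE_EXPORTED void __dace_runkernel_saxpy(float *x, ...)
print(call)         # __dace_runkernel_saxpy(x, y, N);
# ---------------------------------------------------------------------------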
#----------------------------------------------------------------------
+        # Call the OpenCL host function from the callsite
+        #----------------------------------------------------------------------
+
+        callsite_stream.write(
+            "{}({});".format(host_function_name,
+                             ", ".join(kernel_args_call_host)), sdfg, state_id,
+            None)
+
+
+# Unused?
+# def generate_caller_code(self, sdfg, state, kernel_name, params,
+#                          symbol_parameters, function_stream,
+#                          callsite_stream):
+#
+#     state_id = sdfg.node_id(state)
+#
+#     symbol_sigs = [v.ctype + ' ' + k for k, v in symbol_parameters.items()]
+#     symbol_names = symbol_parameters.keys()
+#     kernel_args_call = [p.signature(False) for p in params] + symbol_names
+#     kernel_args_plain = [i.signature() for i in params] + symbol_sigs
+#
+#     kernel_function_name = kernel_name
+#
+#     callsite_stream.write(
+#         "{}({});".format(kernel_function_name,
+#                          ", ".join(kernel_args_call)), sdfg, state_id,
+#         None)
+
+    def generate_module(self, sdfg, state, name, subgraph, params,
+                        symbol_parameters, function_stream, module_stream,
+                        kernel_stream):
+        """Generates a module that will run as a dataflow function in the FPGA
+           kernel."""
+
+        state_id = sdfg.node_id(state)
+        dfg = sdfg.nodes()[state_id]
+
+        symbol_sigs = [
+            v.signature(with_types=True, name=k)
+            for k, v in symbol_parameters.items()
+        ]
+        symbol_names = list(symbol_parameters.keys())
+        kernel_args_call = []
+        kernel_args_module = []
+        added = set()
+        for is_output, pname, p in params:
+            if isinstance(p, dace.data.Array):
+                arr_name = "{}_{}".format(pname, "out" if is_output else "in")
+                kernel_args_call.append(arr_name)
+                kernel_args_module.append("dace::vec<{}, {}> {}*{}".format(
+                    p.dtype.ctype, self._memory_widths[pname],
+                    "const " if not is_output else "", arr_name))
+            else:
+                # Don't make duplicate arguments for other types than arrays
+                if pname in added:
+                    continue
+                added.add(pname)
+                if isinstance(p, dace.data.Stream):
+                    kernel_args_call.append(
+                        p.signature(with_types=False, name=pname))
+                    if p.is_stream_array():
+                        kernel_args_module.append(
+                            "dace::FIFO<{}, {}, {}> {}[{}]".format(
+                                p.dtype.ctype, p.veclen, p.buffer_size, pname,
+                                p.size_string()))
+                    else:
+                        kernel_args_module.append(
+                            "dace::FIFO<{}, {}, {}> &{}".format(
+                                p.dtype.ctype, p.veclen, p.buffer_size, pname))
+                else:
+                    kernel_args_call.append(
+                        p.signature(with_types=False, name=pname))
+                    kernel_args_module.append(
+                        p.signature(with_types=True, name=pname))
+        kernel_args_call += symbol_names
+        kernel_args_module += symbol_sigs
+
+        module_function_name = "module_" + name
+
+        # Unrolling processing elements: if the first scope of the subgraph
+        # is an unrolled map, generate a processing element for each iteration
+        scope_dict = subgraph.scope_dict(node_to_children=True)
+        top_scopes = [
+            n for n in scope_dict[None]
+            if isinstance(n, dace.graph.nodes.EntryNode)
+        ]
+        unrolled_loops = 0
+        if len(top_scopes) == 1:
+            scope = top_scopes[0]
+            if scope.unroll:
+                self._unrolled_pes.add(scope.map)
+                kernel_args_call += scope.map.params
+                kernel_args_module += ["int " + p for p in scope.map.params]
+                for p, r in zip(scope.map.params, scope.map.range):
+                    if len(r) > 3:
+                        raise dace.codegen.codegen.CodegenError(
+                            "Strided unroll not supported")
+                    kernel_stream.write(
+                        "for (int {param} = {begin}; {param} < {end}; "
+                        "{param} += {increment}) {{\n#pragma HLS UNROLL".
+                        format(
+                            param=p, begin=r[0], end=r[1] + 1, increment=r[2]))
+                    unrolled_loops += 1
+
+        # Generate caller code in top-level function
+        kernel_stream.write(
+            "HLSLIB_DATAFLOW_FUNCTION({}, {});".format(
+                module_function_name, ", ".join(kernel_args_call)), sdfg,
+            state_id)
+
+        for _ in range(unrolled_loops):
+            kernel_stream.write("}")
+
+        #----------------------------------------------------------------------
+        # Generate kernel code
+        #----------------------------------------------------------------------
+
+        self._dispatcher.defined_vars.enter_scope(subgraph)
+
+        module_body_stream = CodeIOStream()
+
+        module_body_stream.write(
+            "void {}({}) {{".format(module_function_name,
+                                    ", ".join(kernel_args_module)), sdfg,
+            state_id)
+
+        # Construct ArrayInterface wrappers to pack input and output pointers
+        # to the same global array
+        in_args = {
+            argname
+            for out, argname, arg in params
+            if isinstance(arg, dace.data.Array)
+            and arg.storage == dace.types.StorageType.FPGA_Global and not out
+        }
+        out_args = {
+            argname
+            for out, argname, arg in params
+            if isinstance(arg, dace.data.Array)
+            and arg.storage == dace.types.StorageType.FPGA_Global and out
+        }
+        if len(in_args) > 0 or len(out_args) > 0:
+            # Add ArrayInterface objects to wrap input and output pointers to
+            # the same array
+            module_body_stream.write("\n")
+            interfaces_added = set()
+            for _, argname, arg in params:
+                if argname in interfaces_added:
+                    continue
+                interfaces_added.add(argname)
+                has_in_ptr = argname in in_args
+                has_out_ptr = argname in out_args
+                if not has_in_ptr and not has_out_ptr:
+                    continue
+                in_ptr = ("{}_in".format(argname) if has_in_ptr else "nullptr")
+                out_ptr = ("{}_out".format(argname)
+                           if has_out_ptr else "nullptr")
+                module_body_stream.write(
+                    "dace::ArrayInterface<{}, {}> {}({}, {});".format(
+                        arg.dtype.ctype, self._memory_widths[argname], argname,
+                        in_ptr, out_ptr))
+            module_body_stream.write("\n")
+
+        # Allocate local transients
+        data_to_allocate = (set(subgraph.top_level_transients()) - set(
+            sdfg.shared_transients()) - set([p[1] for p in params]))
+        allocated = set()
+        for node in subgraph.nodes():
+            if not isinstance(node, nodes.AccessNode):
+                continue
+            if node.data not in data_to_allocate or node.data in allocated:
+                continue
+            allocated.add(node.data)
+            self._dispatcher.dispatch_allocate(sdfg, state, state_id, node,
+                                               function_stream,
+                                               module_body_stream)
+            self._dispatcher.dispatch_initialize(sdfg, state, state_id, node,
+                                                 function_stream,
+                                                 module_body_stream)
+
+        self._dispatcher.dispatch_subgraph(
+            sdfg,
+            subgraph,
+            state_id,
+            module_stream,
+            module_body_stream,
+            skip_entry_node=False)
+
+        module_stream.write(module_body_stream.getvalue(), sdfg, state_id)
+        module_stream.write("}\n\n")
+
+        self._dispatcher.defined_vars.exit_scope(subgraph)
+
+    def get_generated_codeobjects(self):
+
+        execution_mode = Config.get("compiler", "xilinx", "mode")
+        sdaccel_dir = os.path.dirname(
+            os.path.dirname(
+                make_absolute(Config.get("compiler", "xilinx", "executable"))))
+        sdaccel_platform = Config.get("compiler", "xilinx", "platform")
+
+        kernel_file_name = "DACE_BINARY_DIR \"{}".format(self._program_name)
+        if execution_mode == "software_emulation":
+            kernel_file_name += "_sw_emu.xclbin\""
+            xcl_emulation_mode = "sw_emu"
+            xilinx_sdx = sdaccel_dir
+        elif execution_mode == "hardware_emulation":
+            kernel_file_name += "_hw_emu.xclbin\""
+            xcl_emulation_mode = "hw_emu"
+            xilinx_sdx = sdaccel_dir
+        elif execution_mode == "hardware" or execution_mode == "simulation":
+            kernel_file_name += "_hw.xclbin\""
"_hw.xclbin\"" + xcl_emulation_mode = None + xilinx_sdx = None + else: + raise dace.codegen.codegen.CodegenError( + "Unknown Xilinx execution mode: {}".format(execution_mode)) + + set_env_vars = "" + set_str = "dace::set_environment_variable(\"{}\", \"{}\");\n" + unset_str = "dace::unset_environment_variable(\"{}\");\n" + set_env_vars += (set_str.format("XCL_EMULATION_MODE", + xcl_emulation_mode) + if xcl_emulation_mode is not None else + unset_str.format("XCL_EMULATION_MODE")) + set_env_vars += (set_str.format("XILINX_SDX", xilinx_sdx) + if xilinx_sdx is not None else + unset_str.format("XILINX_SDX")) + + host_code = CodeIOStream() + host_code.write("""\ +#include "dace/xilinx/host.h" +#include "dace/dace.h" +#include \n\n""") + + self._frame.generate_fileheader(self._global_sdfg, host_code) + + host_code.write(""" +DACE_EXPORTED int __dace_init_xilinx({signature}) {{ + {environment_variables} + hlslib::ocl::GlobalContext().MakeProgram({kernel_file_name}); + return 0; +}} + +{host_code}""".format( + signature=self._global_sdfg.signature(), + environment_variables=set_env_vars, + kernel_file_name=kernel_file_name, + host_code="".join([ + "{separator}\n// Kernel: {kernel_name}" + "\n{separator}\n\n{code}\n\n".format( + separator="/" * 79, kernel_name=name, code=code) + for (name, code) in self._host_codes + ]))) + + host_code_obj = CodeObject(self._program_name + "_host", + host_code.getvalue(), "cpp", XilinxCodeGen, + "Xilinx") + + kernel_code_objs = [ + CodeObject("kernel_" + kernel_name, code, "cpp", XilinxCodeGen, + "Xilinx") for (kernel_name, code) in self._kernel_codes + ] + + return [host_code_obj] + kernel_code_objs + + def allocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + result = StringIO() + nodedesc = node.desc(sdfg) + arrsize = " * ".join([ + cppunparse.pyexpr2cpp(dace.symbolic.symstr(s)) + for s in nodedesc.strides + ]) + is_dynamically_sized = any( + dace.symbolic.issymbolic(s, sdfg.constants) + for s in nodedesc.strides) + + dataname = node.data + + if isinstance(nodedesc, dace.data.Stream): + + if not self._in_device_code: + raise dace.codegen.codegen.CodegenError( + "Cannot allocate FIFO from CPU code: {}".format(node.data)) + + if is_dynamically_sized: + raise dace.codegen.codegen.CodegenError( + "Arrays of streams cannot have dynamic size on FPGA") + + if nodedesc.buffer_size < 1: + raise dace.codegen.codegen.CodegenError( + "Streams cannot be unbounded on FPGA") + + buffer_length_dynamically_sized = ( + isinstance(nodedesc.buffer_size, sp.Expr) + and len(nodedesc.free_symbols) > 0) + + if buffer_length_dynamically_sized: + raise dace.codegen.codegen.CodegenError( + "Buffer length of stream cannot have dynamic size on FPGA") + + if arrsize != "1": + is_stream_array = True + else: + is_stream_array = False + + if is_stream_array: + result.write("dace::FIFO<{}, {}, {}> {}[{}];\n".format( + nodedesc.dtype.ctype, nodedesc.veclen, + nodedesc.buffer_size, dataname, arrsize)) + result.write("dace::SetNames({}, \"{}\", {});".format( + dataname, dataname, arrsize)) + self._dispatcher.defined_vars.add(dataname, + DefinedType.StreamArray) + else: + result.write("dace::FIFO<{}, {}, {}> {}(\"{}\");".format( + nodedesc.dtype.ctype, nodedesc.veclen, + nodedesc.buffer_size, dataname, dataname)) + self._dispatcher.defined_vars.add(dataname, DefinedType.Stream) + + elif isinstance(nodedesc, dace.data.Array): + + if nodedesc.storage == dace.types.StorageType.FPGA_Global: + + if self._in_device_code: + + if nodedesc not in self._allocated_global_arrays: 
+ raise RuntimeError("Cannot allocate global array " + "from device code: {} in {}".format( + node.label, sdfg.name)) + + else: + + devptr_name = dataname + if isinstance(nodedesc, dace.data.Array): + # TODO: Distinguish between read, write, and + # read+write + # TODO: Handle memory banks (location?) + self._allocated_global_arrays.add(node.data) + result.write( + "auto {} = hlslib::ocl::GlobalContext()." + "MakeBuffer<{}, hlslib::ocl::Access::readWrite>" + "({});".format(dataname, nodedesc.dtype.ctype, + arrsize)) + self._dispatcher.defined_vars.add( + dataname, DefinedType.Pointer) + + elif (nodedesc.storage == dace.types.StorageType.FPGA_Local or + nodedesc.storage == dace.types.StorageType.FPGA_Registers): + + if not self._in_device_code: + raise dace.codegen.codegen.CodegenError( + "Tried to allocate local FPGA memory " + "outside device code: {}".format(dataname)) + if is_dynamically_sized: + raise ValueError( + "Dynamic allocation of FPGA fast memory not allowed") + + # Absorb vector size into type and adjust array size + # accordingly + veclen = self._memory_widths[node.data] + generate_scalar = False + if veclen > 1: + arrsize_symbolic = functools.reduce( + sp.mul.Mul, nodedesc.strides) + arrsize_eval = dace.symbolic.eval( + arrsize_symbolic / veclen) + if cpu.sym2cpp(arrsize_eval) == "1": + generate_scalar = True + arrsize_vec = "({}) / {}".format(arrsize, veclen) + else: + arrsize_vec = arrsize + + # If the array degenerates to a single element because of + # vectorization, generate the variable as a scalar instead of + # an array of size 1 + if generate_scalar: + result.write("dace::vec<{}, {}> {};\n".format( + nodedesc.dtype.ctype, veclen, dataname)) + self._dispatcher.defined_vars.add(dataname, + DefinedType.Scalar) + else: + result.write("dace::vec<{}, {}> {}[{}];\n".format( + nodedesc.dtype.ctype, veclen, dataname, arrsize_vec)) + self._dispatcher.defined_vars.add(dataname, + DefinedType.Pointer) + if nodedesc.storage == dace.types.StorageType.FPGA_Registers: + result.write("#pragma HLS ARRAY_PARTITION variable={} " + "complete\n".format(dataname)) + elif len(nodedesc.shape) > 1: + result.write("#pragma HLS ARRAY_PARTITION variable={} " + "block factor={}\n".format( + dataname, nodedesc.shape[-2])) + # result.write( + # "#pragma HLS DEPENDENCE variable={} false".format( + # dataname)) + + else: + raise NotImplementedError("Xilinx: Unimplemented storage type " + + str(nodedesc.storage)) + + elif isinstance(nodedesc, dace.data.Scalar): + + result.write("{} {};\n".format(nodedesc.dtype.ctype, dataname)) + self._dispatcher.defined_vars.add(dataname, DefinedType.Scalar) + + else: + raise TypeError("Unhandled data type: {}".format( + type(nodedesc).__name__)) + + callsite_stream.write(result.getvalue(), sdfg, state_id, node) + + def deallocate_array(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + pass # Handled by destructor + + def _emit_copy(self, sdfg, state_id, src_node, src_storage, dst_node, + dst_storage, dst_schedule, edge, dfg, callsite_stream): + + u, v, memlet = edge.src, edge.dst, edge.data + + cpu_storage_types = [ + dace.types.StorageType.CPU_Heap, dace.types.StorageType.CPU_Stack, + dace.types.StorageType.CPU_Pinned + ] + fpga_storage_types = [ + dace.types.StorageType.FPGA_Global, + dace.types.StorageType.FPGA_Local, + dace.types.StorageType.FPGA_Registers, + ] + + # Determine directionality + if isinstance(src_node, + nodes.AccessNode) and memlet.data == src_node.data: + outgoing_memlet = True + elif isinstance(dst_node, + 
nodes.AccessNode) and memlet.data == dst_node.data: + outgoing_memlet = False + else: + raise LookupError("Memlet does not point to any of the nodes") + + data_to_data = (isinstance(src_node, nodes.AccessNode) + and isinstance(dst_node, nodes.AccessNode)) + + host_to_device = (data_to_data and src_storage in cpu_storage_types and + dst_storage == dace.types.StorageType.FPGA_Global) + device_to_host = (data_to_data + and src_storage == dace.types.StorageType.FPGA_Global + and dst_storage in cpu_storage_types) + device_to_device = ( + data_to_data and src_storage == dace.types.StorageType.FPGA_Global + and dst_storage == dace.types.StorageType.FPGA_Global) + + if (host_to_device or device_to_host) and self._in_device_code: + raise RuntimeError( + "Cannot copy between host and device from device") + + if (host_to_device or device_to_host + or (device_to_device and not self._in_device_code)): + + dims = memlet.subset.dims() + copy_shape = memlet.subset.bounding_box_size() + copysize = ' * '.join([ + cppunparse.pyexpr2cpp(dace.symbolic.symstr(s)) + for s in copy_shape + ]) + offset = cpp_array_expr(sdfg, memlet, with_brackets=False) + + if (not sum(copy_shape) == 1 + and (not isinstance(memlet.subset, subsets.Range) + or any([step != 1 for _, _, step in memlet.subset]))): + raise NotImplementedError("Only contiguous copies currently " + "supported for Xilinx FPGA.") + + if host_to_device: + + callsite_stream.write( + "{}.CopyFromHost({}, {}, {});".format( + dst_node.data, (offset if not outgoing_memlet else 0), + copysize, + src_node.data + (" + {}".format(offset) + if outgoing_memlet else "")), sdfg, + state_id, [src_node, dst_node]) + + elif device_to_host: + + callsite_stream.write( + "{}.CopyToHost({}, {}, {});".format( + src_node.data, (offset + if outgoing_memlet else 0), copysize, + dst_node.data + (" + {}".format(offset) + if not outgoing_memlet else "")), + sdfg, state_id, [src_node, dst_node]) + + elif device_to_device: + + callsite_stream.write( + "{}.CopyToDevice({}, {}, {}, {});".format( + src_node.data, (offset + if outgoing_memlet else 0), copysize, + dst_node.data, (offset if not outgoing_memlet else 0)), + sdfg, state_id, [src_node, dst_node]) + + # Reject copying to/from local memory from/to outside the FPGA + elif (data_to_data + and (((src_storage == dace.types.StorageType.FPGA_Local + or src_storage == dace.types.StorageType.FPGA_Registers) + and dst_storage not in fpga_storage_types) or + ((dst_storage == dace.types.StorageType.FPGA_Local + or dst_storage == dace.types.StorageType.FPGA_Registers) + and src_storage not in fpga_storage_types))): + raise NotImplementedError( + "Copies between host memory and FPGA " + "local memory not supported: from {} to {}".format( + src_node, dst_node)) + + elif data_to_data: + + if memlet.wcr is not None: + raise NotImplementedError("WCR not implemented for copy edges") + + # Try to turn into degenerate/strided ND copies + copy_shape, src_strides, dst_strides, src_expr, dst_expr = ( + self._cpu_codegen.memlet_copy_to_absolute_strides( + sdfg, memlet, src_node, dst_node, packed_types=True)) + + ctype = src_node.desc(sdfg).dtype.ctype + + # TODO: detect in which cases we shouldn't unroll + register_to_register = (src_node.desc( + sdfg).storage == dace.types.StorageType.FPGA_Registers + or dst_node.desc(sdfg).storage == + dace.types.StorageType.FPGA_Registers) + + # Loop intro + num_loops = 0 + for i, copy_dim in enumerate(copy_shape): + if copy_dim != 1: + callsite_stream.write( + "for (auto __dace_copy{} = 0; __dace_copy{} < {}; " + 
"++__dace_copy{}) {{".format(i, i, copy_dim, i)) + if register_to_register: + callsite_stream.write("#pragma HLS UNROLL") + num_loops += 1 + + # Pragmas + if num_loops > 0: + if not register_to_register: + callsite_stream.write("#pragma HLS PIPELINE II=1") + if len(copy_shape) > 1: + callsite_stream.write("#pragma HLS LOOP_FLATTEN") + + # Construct indices (if the length of the stride array is zero, + # resolves to an empty string) + src_index = " + ".join(([""] if len(dst_strides) > 0 else []) + [ + "__dace_copy{} * {}".format(i, cpu.sym2cpp(stride)) + for i, stride in enumerate(src_strides) if copy_shape[i] != 1 + ]) + dst_index = " + ".join(([""] if len(dst_strides) > 0 else []) + [ + "__dace_copy{} * {}".format(i, cpu.sym2cpp(stride)) + for i, stride in enumerate(dst_strides) if copy_shape[i] != 1 + ]) + + src_def_type = self._dispatcher.defined_vars.get(src_node.data) + dst_def_type = self._dispatcher.defined_vars.get(dst_node.data) + + if src_def_type == DefinedType.Stream: + read_expr = src_expr + elif src_def_type == DefinedType.Scalar: + read_expr = src_node.label + else: + read_expr = "dace::Read<{}, {}>({}{})".format( + ctype, memlet.veclen, src_expr, src_index) + + if dst_def_type == DefinedType.Stream: + callsite_stream.write("{}.push({});".format( + dst_expr, read_expr)) + else: + if dst_def_type == DefinedType.Scalar: + write_expr = dst_node.label + callsite_stream.write("dace::Write<{}, {}>({}{}, {});".format( + ctype, memlet.veclen, dst_expr, dst_index, read_expr)) + + # Inject dependence pragmas (DaCe semantics implies no conflict) + for node in [src_node, dst_node]: + if (isinstance(node.desc(sdfg), dace.data.Array) + and node.desc(sdfg).storage in [ + dace.types.StorageType.FPGA_Local, + dace.StorageType.FPGA_Registers + ]): + callsite_stream.write( + "#pragma HLS DEPENDENCE variable={} false".format( + node.data)) + + # Loop outtro + for _ in range(num_loops): + callsite_stream.write("}") + + else: + + self._cpu_codegen.copy_memory(sdfg, dfg, state_id, src_node, + dst_node, edge, None, + callsite_stream) + + @staticmethod + def sdaccel_params(sdfg, kernel_params): + seen = set() + out_params = [] + for is_output, pname, param in kernel_params: + # Since we can have both input and output versions of the same + # array, make sure we only pass it once from the host code + if param in seen: + continue + seen.add(param) + if isinstance(param, dace.data.Array): + out_params.append("hlslib::ocl::Buffer<{}, " + "hlslib::ocl::Access::readWrite> &{}".format( + param.dtype.ctype, pname)) + else: + out_params.append(param.signature(with_types=True, name=pname)) + return out_params + + def get_next_scope_entries(self, sdfg, dfg, scope_entry): + parent_scope_entry = dfg.scope_dict()[scope_entry] + parent_scope = dfg.scope_subgraph(parent_scope_entry) + + # Get all scopes from the same level + all_scopes = [ + node for node in parent_scope.topological_sort() + if isinstance(node, nodes.EntryNode) + ] + + return all_scopes[all_scopes.index(scope_entry) + 1:] + + def generate_node(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + method_name = "_generate_" + type(node).__name__ + # Fake inheritance... use this class' method if it exists, + # otherwise fall back on CPU codegen + if hasattr(self, method_name): + + if hasattr(node, "schedule") and node.schedule not in [ + dace.types.ScheduleType.Default, + dace.types.ScheduleType.FPGA_Device + ]: + # raise dace.codegen.codegen.CodegenError( + # "Cannot produce FPGA code for {} node with schedule {}: ". 
+ # format(type(node).__name__, node.schedule, node)) + print("WARNING: found schedule {} on {} node in FPGA code. " + "Ignoring.".format(node.schedule, + type(node).__name__)) + + getattr(self, method_name)(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + else: + self._cpu_codegen.generate_node(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + + def initialize_array(self, *args, **kwargs): + pass + + def copy_memory(self, sdfg, dfg, state_id, src_node, dst_node, edge, + function_stream, callsite_stream): + + if isinstance(src_node, nodes.Tasklet): + src_storage = dace.types.StorageType.Register + try: + src_parent = dfg.scope_dict()[src_node] + except KeyError: + src_parent = None + dst_schedule = (None + if src_parent is None else src_parent.map.schedule) + else: + src_storage = src_node.desc(sdfg).storage + + if isinstance(dst_node, nodes.Tasklet): + dst_storage = dace.types.StorageType.Register + else: + dst_storage = dst_node.desc(sdfg).storage + + try: + dst_parent = dfg.scope_dict()[dst_node] + except KeyError: + dst_parent = None + dst_schedule = None if dst_parent is None else dst_parent.map.schedule + + state_dfg = sdfg.nodes()[state_id] + + # Emit actual copy + self._emit_copy(sdfg, state_id, src_node, src_storage, dst_node, + dst_storage, dst_schedule, edge, state_dfg, + callsite_stream) + + def _generate_MapEntry(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + + result = callsite_stream + + scope_dict = dfg.scope_dict() + + if node.map in self._unrolled_pes: + + # This is a top-level unrolled map, meaning it has been used to + # replicate processing elements. Don't generate anything here. + pass + + else: + + # Generate nested loops + for i, r in enumerate(node.map.range): + var = node.map.params[i] + begin, end, skip = r + result.write( + "for (auto {} = {}; {} < {}; {} += {}) {{\n".format( + var, cpu.sym2cpp(begin), var, cpu.sym2cpp(end + 1), + var, cpu.sym2cpp(skip)), sdfg, state_id, node) + + # Pipeline innermost loops + scope = dfg.scope_dict(True)[node] + + if node.map.unroll: + result.write("#pragma HLS UNROLL\n", sdfg, state_id, node) + else: + is_innermost = not any( + [isinstance(x, nodes.EntryNode) for x in scope]) + if is_innermost: + result.write( + "#pragma HLS PIPELINE II=1\n#pragma HLS LOOP_FLATTEN", + sdfg, state_id, node) + + if node.map.flatten: + result.write("#pragma HLS LOOP_FLATTEN\n", sdfg, state_id, + node) + + # Emit internal transient array allocation + to_allocate = dace.sdfg.local_transients( + sdfg, sdfg.find_state(state_id), node) + allocated = set() + for child in dfg.scope_dict(node_to_children=True)[node]: + if not isinstance(child, nodes.AccessNode): + continue + if child.data not in to_allocate or child.data in allocated: + continue + allocated.add(child.data) + self._dispatcher.dispatch_allocate(sdfg, dfg, state_id, child, + None, result) + self._dispatcher.dispatch_initialize(sdfg, dfg, state_id, child, + None, result) + + def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + scope_dict = dfg.scope_dict() + entry_node = scope_dict[node] + if entry_node.map in self._unrolled_pes: + # This was generated as unrolled processing elements, no need to + # generate anything here + return + self._cpu_codegen._generate_MapExit(sdfg, dfg, state_id, node, + function_stream, callsite_stream) + + def _generate_Reduce(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + + end_braces = 0 + + axes = node.axes + input_memlet = 
dfg.in_edges(node)[0].data + src_data = sdfg.arrays[input_memlet.data] + output_edge = dfg.out_edges(node)[0] + output_memlet = output_edge.data + dst_data = sdfg.arrays[output_memlet.data] + + output_type = 'dace::vec<%s, %s>' % (dst_data.dtype.ctype, + output_memlet.veclen) + + # If axes were not defined, use all input dimensions + input_dims = input_memlet.subset.dims() + output_dims = output_memlet.subset.data_dims() + if axes is None: + axes = tuple(range(input_dims)) + output_axes = [a for a in range(input_dims) if a not in axes] + + # Obtain variable names per output and reduction axis + axis_vars = [] + unroll_dim = [] + octr = 0 + for d in range(input_dims): + if d in axes: + axis_vars.append('__i%d' % d) + else: + axis_vars.append('__o%d' % octr) + octr += 1 + if ((isinstance(src_data, dace.data.Stream) + and src_data.is_stream_array()) or + (isinstance(src_data, dace.data.Array) and + src_data.storage == dace.types.StorageType.FPGA_Registers)): + # Unroll reads from registers and stream arrays + unroll_dim.append(True) + else: + unroll_dim.append(False) + + # We want to pipeline the last non-unrolled dimension + pipeline_dim = -1 + for i in itertools.chain(axes, output_axes): + if not unroll_dim[i]: + pipeline_dim = i + + if node.identity is not None: + identity = cpu.sym2cpp(node.identity) + else: + identity = None + + # Initialize accumulator variable if we're collapsing to a single value + all_axes_collapsed = (len(axes) == input_dims) + if all_axes_collapsed: + accumulator = "_{}_accumulator".format(output_memlet.data) + callsite_stream.write("{} {};".format(output_type, accumulator), + sdfg, state_id, node) + + # Generate inner loops (for each collapsed dimension) + input_subset = input_memlet.subset + iterators_inner = ["__i{}".format(axis) for axis in axes] + for i, axis in enumerate(axes): + callsite_stream.write( + 'for (int {var} = {begin}; {var} < {end}; {var} += {skip}) {{'. + format( + var=iterators_inner[i], + begin=input_subset[axis][0], + end=input_subset[axis][1] + 1, + skip=input_subset[axis][2]), sdfg, state_id, node) + if unroll_dim[axis]: + callsite_stream.write("#pragma HLS UNROLL\n") + if axis == pipeline_dim: + callsite_stream.write( + "#pragma HLS PIPELINE II=1\n#pragma HLS LOOP_FLATTEN") + end_braces += 1 + + # Generate outer loops (over different output locations) + output_subset = output_memlet.subset + iterators_outer = ["__o{}".format(axis) for axis in range(output_dims)] + for i, axis in enumerate(output_axes): + callsite_stream.write( + 'for (int {var} = {begin}; {var} < {end}; {var} += {skip}) {{'. 
+ format( + var=iterators_outer[i], + begin=output_subset[i][0], + end=output_subset[i][1] + 1, + skip=output_subset[i][2]), sdfg, state_id, node) + if unroll_dim[axis]: + callsite_stream.write("#pragma HLS UNROLL\n") + if axis == pipeline_dim: + callsite_stream.write( + "#pragma HLS PIPELINE II=1\n#pragma HLS LOOP_FLATTEN") + end_braces += 1 + + # Determine reduction type + reduction_type = operations.detect_reduction_type(node.wcr) + if reduction_type == dace.types.ReductionType.Custom: + raise NotImplementedError("Custom reduction for FPGA is NYI") + + # Input and output variables + out_var = (accumulator + if all_axes_collapsed else cpp_array_expr( + sdfg, + output_memlet, + offset=iterators_outer, + relative_offset=False)) + in_var = cpp_array_expr( + sdfg, input_memlet, offset=axis_vars, relative_offset=False) + + # Call library function to perform reduction + reduction_cpp = "dace::Reduce<{}, {}, {}, {}<{}>>".format( + dst_data.dtype.ctype, input_memlet.veclen, output_memlet.veclen, + REDUCTION_TYPE_TO_HLSLIB[reduction_type], dst_data.dtype.ctype) + + # Check if this is the first iteration of accumulating into this + # location + is_first_iteration = " && ".join([ + "{} == {}".format(iterators_inner[i], input_subset[axis][0]) + for i, axis in enumerate(axes) + ]) + if identity is not None: + # If this is the first iteration, set the previous value to be + # identity, otherwise read the value from the output location + prev_var = "{}_prev".format(output_memlet.data) + callsite_stream.write( + "{} {} = ({}) ? ({}) : ({});".format( + output_type, prev_var, is_first_iteration, identity, + out_var), sdfg, state_id, node) + callsite_stream.write( + "{} = {}({}, {});".format(out_var, reduction_cpp, prev_var, + in_var), sdfg, state_id, node) + else: + # If this is the first iteration, assign the value read from the + # input directly to the output + callsite_stream.write( + "{} = ({}) ? ({}) : {}({}, {});".format( + out_var, is_first_iteration, in_var, reduction_cpp, + out_var, in_var), sdfg, state_id, node) + + # Generate closing braces + for i in range(end_braces): + callsite_stream.write('}', sdfg, state_id, node) + if i == end_braces - 1 and all_axes_collapsed: + dst_expr = output_memlet.data + offset = cpp_offset_expr( + dst_data, + output_memlet.subset, + packed_veclen=output_memlet.veclen) + if offset: + dst_expr += " + " + offset + callsite_stream.write( + "dace::Write({}, {});".format(dst_expr, out_var), sdfg, + state_id, node) + + def _generate_Tasklet(self, sdfg, dfg, state_id, node, function_stream, + callsite_stream): + + # TODO: this is copied from the CPU-codegen, necessary to inject + # pragmas at the output memlets! Should consolidate. 
+
+        callsite_stream.write('{\n', sdfg, state_id, node)
+
+        state_dfg = sdfg.nodes()[state_id]
+
+        self._dispatcher.defined_vars.enter_scope(node)
+
+        arrays = set()
+        for edge in dfg.in_edges(node):
+            u = edge.src
+            memlet = edge.data
+
+            if edge.dst_conn:  # Not (None or "")
+                if edge.dst_conn in arrays:  # Disallow duplicates
+                    raise SyntaxError('Duplicates found in memlets')
+                # Special case: code->code
+                if isinstance(edge.src, nodes.CodeNode):
+                    shared_data_name = 's%d_n%d%s_n%d%s' % (
+                        state_id, dfg.node_id(edge.src), edge.src_conn,
+                        dfg.node_id(edge.dst), edge.dst_conn)
+
+                    # Read variable from shared storage
+                    callsite_stream.write(
+                        'const dace::vec<%s, %s>& %s = __%s;' %
+                        (sdfg.arrays[edge.data.data].dtype.ctype,
+                         cpu.sym2cpp(edge.data.veclen),
+                         edge.dst_conn, shared_data_name), sdfg, state_id,
+                        [edge.src, edge.dst])
+                    self._dispatcher.defined_vars.add(edge.dst_conn,
+                                                      DefinedType.Scalar)
+
+                else:
+                    src_node = find_input_arraynode(state_dfg, edge)
+
+                    self._dispatcher.dispatch_copy(
+                        src_node, node, edge, sdfg, state_dfg, state_id,
+                        function_stream, callsite_stream)
+
+                # Also define variables in the C++ unparser scope
+                self._cpu_codegen._locals.define(edge.dst_conn, -1,
+                                                 self._cpu_codegen._ldepth + 1)
+                arrays.add(edge.dst_conn)
+
+        callsite_stream.write('\n', sdfg, state_id, node)
+
+        # Use outgoing edges to preallocate output local vars
+        for edge in dfg.out_edges(node):
+            v = edge.dst
+            memlet = edge.data
+
+            if edge.src_conn:
+                if edge.src_conn in arrays:  # Disallow duplicates
+                    continue
+                # Special case: code->code
+                if isinstance(edge.dst, nodes.CodeNode):
+                    callsite_stream.write(
+                        'dace::vec<%s, %s> %s;' %
+                        (sdfg.arrays[memlet.data].dtype.ctype,
+                         cpu.sym2cpp(memlet.veclen), edge.src_conn), sdfg,
+                        state_id, [edge.src, edge.dst])
+                    self._dispatcher.defined_vars.add(edge.src_conn,
+                                                      DefinedType.Scalar)
+                else:
+                    dst_node = find_output_arraynode(state_dfg, edge)
+
+                    self._dispatcher.dispatch_copy(
+                        node, dst_node, edge, sdfg, state_dfg, state_id,
+                        function_stream, callsite_stream)
+
+                # Also define variables in the C++ unparser scope
+                self._cpu_codegen._locals.define(edge.src_conn, -1,
+                                                 self._cpu_codegen._ldepth + 1)
+                arrays.add(edge.src_conn)
+
+        callsite_stream.write('\n    ///////////////////\n', sdfg, state_id,
+                              node)
+
+        cpu.unparse_tasklet(sdfg, state_id, dfg, node, function_stream,
+                            callsite_stream, self._cpu_codegen._locals,
+                            self._cpu_codegen._ldepth)
+
+        callsite_stream.write('    ///////////////////\n\n', sdfg, state_id,
+                              node)
+
+        # Process outgoing memlets
+        self._cpu_codegen.process_out_memlets(
+            sdfg, state_id, node, state_dfg, self._dispatcher, callsite_stream,
+            True, function_stream)
+
+        for edge in state_dfg.out_edges(node):
+            datadesc = sdfg.arrays[edge.data.data]
+            if (isinstance(datadesc, dace.data.Array) and
+                    (datadesc.storage == dace.types.StorageType.FPGA_Local
+                     or datadesc.storage == dace.types.StorageType.FPGA_Registers)
+                    and edge.data.wcr is None):
+                callsite_stream.write(
+                    "#pragma HLS DEPENDENCE variable=__{} false".format(
+                        edge.src_conn))
+
+        callsite_stream.write('}\n', sdfg, state_id, node)
+
+        self._dispatcher.defined_vars.exit_scope(node)
diff --git a/dace/codegen/tools/dacestub.cpp b/dace/codegen/tools/dacestub.cpp
new file mode 100644
index 0000000000..e2a76b54ed
--- /dev/null
+++ b/dace/codegen/tools/dacestub.cpp
@@ -0,0 +1,85 @@
+/**
+ * Stub library that can load other libraries for use as DaCe programs
+**/
+
+#ifdef _WIN32
+    #include <windows.h>
+    #define DACE_EXPORTED extern "C" __declspec(dllexport)
+#else
+    #include <dlfcn.h>
+    #define DACE_EXPORTED extern "C"
+#endif
+
+// Workaround (see unload_library)
+#include <omp.h>
+
+// Loads a library and returns a handle to it, or NULL if there was an error
+// NOTE: On Windows, path must be given as a Unicode string (UTF-16, or
+// ctypes.c_wchar_p)
+DACE_EXPORTED void *load_library(const char *filename) {
+    if (!filename)
+        return nullptr;
+
+    void *hLibrary = nullptr;
+
+#ifdef _WIN32
+    hLibrary = (void *)LoadLibraryW((const wchar_t*)filename);
+#else
+    hLibrary = dlopen(filename, RTLD_LOCAL | RTLD_NOW);
+#endif
+
+    return hLibrary;
+}
+
+// Returns 1 if the library is already loaded, 0 if not, or -1 on error
+DACE_EXPORTED int is_library_loaded(const char *filename) {
+    if (!filename)
+        return -1;
+
+    void *hLibrary = nullptr;
+
+#ifdef _WIN32
+    hLibrary = (void *)GetModuleHandleW((const wchar_t*)filename);
+#else
+    hLibrary = dlopen(filename, RTLD_LOCAL | RTLD_NOW | RTLD_NOLOAD);
+#endif
+
+    if (hLibrary)
+        return 1;
+    return 0;
+}
+
+// Loads a library function and returns a pointer, or NULL if it was not found
+DACE_EXPORTED void *get_symbol(void *hLibrary, const char *symbol) {
+    if (!hLibrary || !symbol)
+        return nullptr;
+
+    void *sym = nullptr;
+
+#ifdef _WIN32
+    sym = GetProcAddress((HMODULE)hLibrary, symbol);
+#else
+    sym = dlsym(hLibrary, symbol);
+#endif
+
+    return sym;
+}
+
+// Unloads the given library handle (no-op if NULL)
+DACE_EXPORTED void unload_library(void *hLibrary) {
+    if (!hLibrary)
+        return;
+
+    // Workaround so that OpenMP does not go ballistic when calling dlclose()
+    omp_get_max_threads();
+
+#ifdef _WIN32
+    FreeLibrary((HMODULE)hLibrary);
+#else
+    dlclose(hLibrary);
+#endif
+}
+
+
diff --git a/dace/codegen/tools/get_cuda_arch.cpp b/dace/codegen/tools/get_cuda_arch.cpp
new file mode 100644
index 0000000000..88442f0729
--- /dev/null
+++ b/dace/codegen/tools/get_cuda_arch.cpp
@@ -0,0 +1,33 @@
+#include <cuda_runtime.h>
+#include <iostream>
+#include <sstream>
+#include <set>
+#include <string>
+
+
+int main(int argc, char **argv)
+{
+    int count;
+    if (cudaGetDeviceCount(&count) != cudaSuccess)
+        return 1;
+
+    std::set<std::string> architectures;
+    // Loop over all GPUs and collect their architectures
+    for (int i = 0; i < count; ++i)
+    {
+        cudaDeviceProp prop;
+        if (cudaGetDeviceProperties(&prop, i) != cudaSuccess ||
+            (prop.major == 99 && prop.minor == 99))
+            continue;
+        std::stringstream ss;
+        ss << prop.major << prop.minor;
+        architectures.insert(ss.str());
+    }
+
+    // Print out architectures
+    for (std::set<std::string>::iterator iter = architectures.begin();
+         iter != architectures.end(); ++iter)
+        std::cout << *iter << " ";
+
+    return 0;
+}
diff --git a/dace/config.py b/dace/config.py
new file mode 100644
index 0000000000..e6415a4a7b
--- /dev/null
+++ b/dace/config.py
@@ -0,0 +1,265 @@
+import os
+import platform
+import yaml
+
+
+def _env2bool(envval):
+    """ Converts an arbitrary value to boolean.
+        @param envval: Arbitrary value.
+        @return: True if the input value matches a valid TRUE
+                 value, or False otherwise.
+    """
+    return str(envval).lower() in ['true', '1', 'y', 'yes', 'on']
+
+
+def _add_defaults(config, metadata):
+    """ Add defaults to configuration from metadata.
+        @return: True if configuration was modified, False otherwise.
+ """ + osname = platform.system() + modified = False + for k, v in metadata.items(): + # Recursive call for fields inside the dictionary + if v['type'] == 'dict': + if k not in config: + modified = True + config[k] = {} + modified |= _add_defaults(config[k], v['required']) + continue + # Empty list initialization (if no default is specified) + elif v['type'] == 'list': + if k not in config and 'default' not in v: + modified = True + config[k] = [] + continue + # Key does not exist in configuration, add default value + if k not in config: + modified = True + # Per-OS default + if 'default_' + osname in v: + config[k] = v['default_' + osname] + else: + config[k] = v['default'] + return modified + + +class Config(object): + """ Interface to the DaCe hierarchical configuration file. """ + + _config = {} + _config_metadata = {} + _cfg_filename = None + _metadata_filename = None + + @staticmethod + def cfg_filename(): + """ Returns the current configuration file path. """ + + return Config._cfg_filename + + @staticmethod + def initialize(): + """ Initializes configuration. + + B{Note:} This function runs automatically when the module + is loaded. + """ + + # If already initialized, skip + if Config._cfg_filename is not None: + return + + # Override default configuration file path + if 'DACE_CONFIG' in os.environ: + cfg_filename = os.environ['DACE_CONFIG'] + else: + home = os.path.expanduser("~") + cfg_filename = os.path.join(home, ".dace.conf") + + Config._cfg_filename = cfg_filename + + dace_path = os.path.dirname(os.path.abspath(__file__)) + Config._metadata_filename = os.path.join(dace_path, + 'config_schema.yml') + + # Load configuration schema (for validation and defaults) + Config.load_schema() + + if os.path.isfile(cfg_filename): + Config.load() + else: + # Load the defaults from metadata and save new conf file + Config._config = {} + _add_defaults(Config._config, Config._config_metadata['required']) + Config.save() + + @staticmethod + def load(filename=None): + """ Loads a configuration from an existing file. + @param filename: The file to load. If unspecified, + uses default configuration file. + """ + if filename is None: + filename = Config._cfg_filename + + # Read configuration file + with open(filename, 'r') as f: + Config._config = yaml.load(f.read()) + + # Add defaults from metadata + modified = _add_defaults(Config._config, + Config._config_metadata['required']) + if modified: # Update file if changed + Config.save() + + @staticmethod + def load_schema(filename=None): + """ Loads a configuration schema from an existing file. + @param filename: The file to load. If unspecified, + uses default schema file. + """ + if filename is None: + filename = Config._metadata_filename + with open(filename, 'r') as f: + Config._config_metadata = yaml.load(f.read()) + + @staticmethod + def save(path=None): + """ Saves the current configuration to a file. + @param path: The file to save to. If unspecified, + uses default configuration file. + """ + if path is None: + path = Config._cfg_filename + # Write configuration file + with open(path, 'w') as f: + yaml.dump(Config._config, f, default_flow_style=False) + + @staticmethod + def get_metadata(*key_hierarchy): + """ Returns the configuration specification of a given entry + from the schema. + @param key_hierarchy: A tuple of strings leading to the + configuration entry. + For example: ('a', 'b', 'c') would be + configuration entry c which is in the + path a->b. + @return: Configuration specification as a dictionary. 
+ """ + # Traverse the key hierarchy + current_conf = Config._config_metadata + for key in key_hierarchy: + current_conf = current_conf['required'][key] + return current_conf + + @staticmethod + def get_default(*key_hierarchy): + """ Returns the default value of a given configuration entry. + Takes into accound current operating system. + @param key_hierarchy: A tuple of strings leading to the + configuration entry. + For example: ('a', 'b', 'c') would be + configuration entry c which is in the + path a->b. + @return: Default configuration value. + """ + # Traverse the key hierarchy + current_conf = Config._config_metadata + for key in key_hierarchy: + current_conf = current_conf['required'][key] + if 'default_' + platform.system() in current_conf: + return current_conf['default_' + platform.system()] + return current_conf['default'] + + @staticmethod + def get(*key_hierarchy): + """ Returns the current value of a given configuration entry. + @param key_hierarchy: A tuple of strings leading to the + configuration entry. + For example: ('a', 'b', 'c') would be + configuration entry c which is in the + path a->b. + @return: Configuration entry value. + """ + # Environment variable override + # NOTE: will only work if a specific key is accessed! + envvar = 'DACE_' + '_'.join(key_hierarchy) + if envvar in os.environ: + return os.environ[envvar] + + # Traverse the key hierarchy + current_conf = Config._config + for key in key_hierarchy: + current_conf = current_conf[key] + + return current_conf + + @staticmethod + def get_bool(*key_hierarchy): + """ Returns the current value of a given boolean configuration entry. + This specialization allows more string types to be converted to + boolean, e.g., due to environment variable overrides. + @param key_hierarchy: A tuple of strings leading to the + configuration entry. + For example: ('a', 'b', 'c') would be + configuration entry c which is in the + path a->b. + @return: Configuration entry value (as a boolean). + """ + res = Config.get(*key_hierarchy) + if isinstance(res, bool): + return res + return _env2bool(str(res)) + + @staticmethod + def append(*key_hierarchy, value=None, autosave=False): + """ Appends to the current value of a given configuration entry + and sets it. Example usage: + `Config.append('compiler', 'cpu', 'args', value='-fPIC')` + @param key_hierarchy: A tuple of strings leading to the + configuration entry. + For example: ('a', 'b', 'c') would be + configuration entry c which is in the + path a->b. + @param value: The value to append. + @param autosave: If True, saves the configuration to the file + after modification. + @return: Current configuration entry value. + """ + # Traverse the key hierarchy up until the next to last element + current_conf = Config._config + for key in key_hierarchy[:-1]: + current_conf = current_conf[key] + + current_conf[key_hierarchy[-1]] += value + if autosave: + Config.save() + + return current_conf[key_hierarchy[-1]] + + @staticmethod + def set(*key_hierarchy, value=None, autosave=False): + """ Sets the current value of a given configuration entry. + Example usage: + `Config.set('profiling', value=True)` + @param key_hierarchy: A tuple of strings leading to the + configuration entry. + For example: ('a', 'b', 'c') would be + configuration entry c which is in the + path a->b. + @param value: The value to set. + @param autosave: If True, saves the configuration to the file + after modification. 
+ """ + # Traverse the key hierarchy up until the next to last element + current_conf = Config._config + for key in key_hierarchy[:-1]: + current_conf = current_conf[key] + + current_conf[key_hierarchy[-1]] = value + if autosave: + Config.save() + + +# Code that runs when the module is loaded +Config.initialize() diff --git a/dace/config_schema.yml b/dace/config_schema.yml new file mode 100644 index 0000000000..d0cdb5c552 --- /dev/null +++ b/dace/config_schema.yml @@ -0,0 +1,602 @@ +# Schema file for DaCe Preferences + +# Metadata fields for elements: +# type: any python type (dict, list, int, bool, float, str) +# title: short name to show in GUI +# description: tooltip to show in GUI +# required: required sub-fields (for dict fields) +# default: default value. Can be platform-specific (see below) +# default_: default value for platform (overrides default) +# template_vars: template variables to include when processing (str fields only) + +# Top-level element is a dictionary (record) +type: dict +title: General +description: DaCe Preferences +required: + ############################################# + # Categories + optimizer: + type: dict + title: Optimizer + description: Preferences of the SDFG Optimizer + required: + autospecialize: + type: bool + default: false + title: Auto-specialize symbols + description: > + Automatically specialize every SDFG to the symbol values + at call-time. Requires all symbols to be set. + + interface: + type: bool + default: dace.transformation.optimizer.SDFGOptimizer + title: SDFG Optimizer + description: > + SDFG optimization class to import and call automatically + on compilation. Defaults to the transformation CLI, empty + string or an invalid class name skips the process. + + visualize: + type: bool + default: false + title: Visualize SDFG + description: Open a GraphViz window after every transformation. + + savedots: + type: bool + default: false + title: Save dot files + description: Save GraphViz .dot files after every transformation. + + automatic_state_fusion: + type: bool + default: true + title: Automatic strict transformations + description: > + Automatically performs strict transformations + that are considered to be safe. + + detect_control_flow: + type: bool + default: true + title: Detect control flow from state transitions + description: > + Attempts to infer control flow constructs "if", + "for" and "while" from state transitions, allowing + code generators to generate appropriate code. + + renderer: + type: dict + title: Renderer + description: Preferences of the SDFG Renderer + required: + fulledges: + type: bool + default: false + title: Show full edges + description: > + If enabled, prints out the full edge labels (which may be + long due to complex indexing). + html5renderer: + type: bool + default: false + title: (EXPERIMENTAL) HTML5 Rendering Engine + description: > + If enabled, uses an HTML5-based renderer to display SDFGs. + This allows to visualize performance data, but is still experimental. + + compiler: + type: dict + title: Compiler + description: Preferences of the compiler + required: + use_cache: + type: bool + default: false + title: Use cache + description: > + If enabled, does not recompile code generated from SDFGs + if shared library (.so/.dll) file is present. + + library_extension: + type: str + default: so + default_Linux: so + default_Windows: dll + default_Darwin: dylib + title: Library extension + description: File extension of shared libraries. 
+
+            indentation_spaces:
+                type: int
+                default: 4
+                title: Indentation width
+                description: >
+                    Number of spaces used when indenting generated code.
+
+            build_type:
+                type: str
+                default: Release
+                title: Build configuration
+                description: >
+                    Configuration type for CMake build (can be Debug, Release,
+                    RelWithDebInfo, or MinSizeRel).
+
+            allow_shadowing:
+                type: bool
+                default: false
+                title: Allow variable shadowing
+                description: >
+                    Allows shadowing of variables in the code (reduces
+                    exceptions to warnings when shadowing is encountered).
+
+            #############################################
+            # CPU compiler
+            cpu:
+                type: dict
+                title: CPU
+                description: CPU compiler preferences
+                required:
+                    executable:
+                        type: str
+                        default: g++
+                        default_Windows: cl
+                        title: Compiler executable name
+                        description: File path or name of compiler executable
+
+                    args:
+                        type: str
+                        title: Arguments
+                        description: Compiler argument flags
+                        default: '-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label'
+                        default_Windows: '/O2 /fp:fast /arch:AVX2 /D_USRDLL /D_WINDLL /D__restrict__=__restrict'
+
+                    additional_args:
+                        type: str
+                        title: Extra Arguments
+                        description: Additional arguments provided by users
+                        default: ''
+
+                    libs:
+                        type: str
+                        title: Additional libraries
+                        description: Additional linked libraries required by target
+                        default: ''
+
+            #############################################
+            # GPU (CUDA) compiler
+            cuda:
+                type: dict
+                title: GPU
+                description: GPU (CUDA) compiler preferences
+                required:
+                    executable:
+                        type: str
+                        default: nvcc
+                        title: Compiler executable name
+                        description: File path or name of compiler executable
+
+                    args:
+                        type: str
+                        title: Arguments
+                        description: Compiler argument flags
+                        default: '-std=c++14 -Xcompiler -fPIC -O3 -Xcompiler -march=native --use_fast_math -Xcompiler -Wno-unused-parameter'
+
+                    cuda_arch:
+                        type: str
+                        title: Additional CUDA architectures
+                        description: >
+                            Additional CUDA architectures (separated by commas)
+                            to compile GPU code for, excluding the current
+                            architecture on the compiling machine.
+                        default: '35'
+
+                    default_block_size:
+                        type: str
+                        title: Default thread-block size
+                        description: >
+                            Default thread-block size for CUDA kernels when
+                            explicit GPU block maps are not defined.
+                        default: '32,1,1'
+
+                    max_concurrent_streams:
+                        type: int
+                        title: Concurrent CUDA streams
+                        description: >
+                            Maximum number of concurrent CUDA streams to
+                            generate. Special values: -1 only uses the
+                            default stream, 0 uses infinite concurrent streams.
+                        default: 0
+
+                    additional_args:
+                        type: str
+                        title: Extra Arguments
+                        description: Additional arguments provided by users
+                        default: ''
+
+                    libs:
+                        type: str
+                        title: Additional libraries
+                        description: Additional linked libraries required by target
+                        default: ''
+
+            #############################################
+            # FPGA (Xilinx) compiler flags
+            xilinx:
+                type: dict
+                title: Xilinx
+                description: FPGA (Xilinx) compiler preferences
+                required:
+
+                    mode:
+                        type: str
+                        default: simulation
+                        title: Compilation mode
+                        description: Target of FPGA kernel build (simulation/software_emulation/hardware_emulation/hardware)
+
+                    executable:
+                        type: str
+                        default: xocc
+                        title: SDAccel compiler executable path
+                        description: File path or name of SDAccel binary (xocc)
+
+                    platform:
+                        type: str
+                        default: xilinx_vcu1525_dynamic_5_1
+                        title: Target platform for xocc
+                        description: Platform name of SDAccel target.
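The `default_<OS>` convention used by entries such as `executable` above resolves against `platform.system()`. A small sketch mirroring the lookup implemented in `_add_defaults` and `Config.get_default` in `dace/config.py` (the `entry` dictionary below is a made-up example):

```python
import platform

# A 'default_<OS>' key overrides the generic 'default' on that platform,
# exactly as in _add_defaults / Config.get_default.
entry = {'type': 'str', 'default': 'g++', 'default_Windows': 'cl'}
osname = platform.system()  # 'Linux', 'Windows', or 'Darwin'
value = entry.get('default_' + osname, entry['default'])
print(value)  # 'cl' on Windows, 'g++' elsewhere
```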
+ + enable_debugging: + type: bool + default: false + title: Enable debugging for hardware kernels + description: > + Injects debugging cores on the interfaces of the + kernel, allowing fine-grained debugging of hardware + runs at the cost of additional resources. This is + always enabled for emulation runs. + + host_flags: + type: str + title: Host arguments + description: Extra host compiler argument flags + default: "-Wno-unknown-pragmas -Wno-unused-label" + + synthesis_flags: + type: str + title: Synthesis arguments + description: High-level synthesis C++ flags + default: "-std=c++11" + + build_flags: + type: str + title: Arguments + description: Kernel build (xocc) C++ flags + default: "" + + ############################################# + # MPI compiler + mpi: + type: dict + title: MPI + description: MPI compiler preferences + required: + executable: + type: str + default: mpicxx + title: Compiler executable name + description: File path or name of compiler executable + + ############################################# + # Linker + linker: + type: dict + title: Linker + description: Linker preferences + required: + executable: + type: str + default: g++ + default_Windows: cl + title: Linker executable name + description: File path or name of linker executable + + args: + type: str + title: Arguments + description: Linker argument flags + default: '' + + additional_args: + type: str + title: Extra Arguments + description: Additional arguments provided by users + default: '' + template_envvars: + - CUDA_PATH + + library_prefix: + type: str + title: Library argument prefix + description: > + Argument prefix to add before each added library. + default: '-l' + default_Windows: '' + + library_suffix: + type: str + title: Library argument suffix + description: > + Argument suffix to add after each added library. + default: '' + default_Windows: '.lib' + + execution: + type: dict + title: Execution + description: Binary execution preferences + required: + general: + type: dict + title: General + description: General execution preferences + required: + host: + type: str + default: localhost + title: Host + description: Hostname to use for execution + + workdir: + type: str + default: '/tmp/' + title: Working directory + description: Working directory on the remote host + + check_args: + type: bool + default: true + title: Check arguments + description: > + Do strict verification that arguments passed when + calling a DaCe program match the expected types. + + execcmd: + type: str + title: Command + description: > + Command to use to execute ${command} on ${host} + default: 'ssh ${host} ${command}' + template_vars: + - host + - command + + copycmd_r2l: + type: str + default: 'scp ${host}:${srcfile} ${dstfile}' + title: "Remote->Local copy command" + description: > + Command to use to copy ${srcfile} on ${host} to + the local ${dstfile}. + template_vars: + - host + - srcfile + - dstfile + + copycmd_l2r: + type: str + default: "scp ${srcfile} ${host}:${dstfile}" + title: "Local->Remote copy command" + description: > + Command to use to copy the local ${srcfile} to the + remote ${dstfile}. + template_vars: + - host + - srcfile + - dstfile + + repetitions: + type: int + default: 5 + title: "Repetitions per Run" + description: > + Number of repetitions to run for each click of the + Run button (median value will be reported in the + performance chart). 
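The `template_vars` fields above declare `${...}` placeholders in entries such as `execcmd`. The code performing the substitution is not part of this diff, so the following is only an assumed illustration using Python's `string.Template`, whose placeholder syntax matches; the host and command values are hypothetical:

```python
from string import Template

# Hypothetical expansion of the 'execcmd' entry above.
execcmd = 'ssh ${host} ${command}'
cmd = Template(execcmd).substitute(host='node01', command='./myprogram')
print(cmd)  # ssh node01 ./myprogram
```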
+ mpi: + type: dict + title: MPI + description: MPI execution preferences + required: + mpiexec: + type: str + default: 'mpirun -n ${num_procs} ${command}' + title: mpirun command + description: > + Command to use to execute MPI job ${command} with + ${num_procs} processes. + template_vars: + - num_procs + - command + + num_procs: + type: int + default: 4 + title: Number of processes + description: Number of MPI processes to use + diode: + type: dict + title: DIODE + description: DIODE GUI preferences + required: + layout: + type: dict + title: Layout + description: Window layout preferences + required: + window_width: + default: 800 + title: Window Width + type: float + description: Window width (in pixels) + + window_height: + default: 600 + title: Window Height + type: float + description: Window height (in pixels) + + window_maximized: + default: True + title: Window Maximized + type: bool + description: > + If True, DIODE starts with a maximized window + + toppane_height: + default: 20 + type: float + title: Top-Pane Height + description: > + Height of top pane in Optimizer view (in percentage) + + pypane_width: + default: 30 + title: Python Pane Width + type: float + description: > + Width of the Python Editor pane (in percentage) + + optpane_width: + default: 30 + title: Transformation Pane Width + type: float + description: > + Width of the Transformation pane (in percentage) + + codepane_width: + default: 30 + title: Generated Code Pane Width + type: float + description: > + Width of the Generated Code pane (in percentage) + + perfpane_width: + default: 30 + title: Performance Pane Width + type: float + description: > + Width of the Performance graph pane (in percentage) + + general: + type: dict + title: General + description: General DIODE Preferences + required: + + show_transfed: + type: bool + default: False + title: (EXPERIMENTAL) Show Transformation Editor + description: > + Show (or hide) the experimental transformation + editor. + + show_sdfged: + type: bool + default: False + title: (EXPERIMENTAL) Show SDFG Editor + description: > + Show (or hide) the experimental SDFG Editor. + + show_optgraph: + type: bool + default: False + title: Show Optimization Graph + description: > + Show available transformations as a graph. This is + discouraged as the optimization graph may be too + large to be useful. + + fonts: + type: dict + title: Fonts + description: Fonts used in editors + required: + python: + default: '' + title: Python + type: font + description: Font used to render Python code. + + codegen: + default: '' + title: Generated Code + type: font + description: Font used to render generated code. + + pated: + default: '' + title: Transformation Editor + type: font + description: Font used to render pattern match code. + + instrumentation: + type: dict + title: Instrumentation + description: Instrumentation preferences + required: + enable_papi: + type: bool + title: Enable PAPI + default: false + description: Enable instrumentation using PAPI + enable_vectorization_analysis: + type: bool + title: Enable vectorization check + default: false + description: > + Enables analysis of gcc vectorization information. Only gcc/g++ is supported. 
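Note that the `default_papi_counters` entry just below stores its counter list as a Python-literal string. How DaCe itself parses it is not shown in this diff; one safe way to recover the list, shown purely as an assumed approach, is `ast.literal_eval`:

```python
import ast

# Assumed parsing of the Python-literal counter list; the string is the
# schema default from the entry below.
raw = "['PAPI_TOT_INS', 'PAPI_TOT_CYC', 'PAPI_L2_TCM', 'PAPI_L3_TCM']"
counters = ast.literal_eval(raw)
assert counters[0] == 'PAPI_TOT_INS'
```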
+ enable_papi_counter_sanity_check:
+ type: bool
+ title: Counter sanity check
+ default: false
+ description: >
+ Enables a pre-run sanity check to minimize runtime failures.
+ default_papi_counters:
+ type: str
+ title: Default PAPI counters
+ default: "['PAPI_TOT_INS', 'PAPI_TOT_CYC', 'PAPI_L2_TCM', 'PAPI_L3_TCM']"
+ description: >
+ Sets the default PAPI counter list, formatted as
+ a Python list of strings.
+ max_scope_depth:
+ type: int
+ title: Max scope depth
+ default: 5
+ description: >
+ Sets the maximum depth of instrumentations in
+ map/consume scopes. Scopes that are deeper will not
+ be instrumented.
+
+ #############################################
+ # General settings
+ debugprint:
+ type: bool
+ default: true
+ title: Debug printing
+ description: Enable verbose printouts.
+
+ profiling:
+ type: bool
+ default: false
+ title: Profiling
+ description: Enable profiling support.
+
+ treps:
+ type: int
+ default: 100
+ title: Profiling Repetitions
+ description: Number of times to run program for profiling. diff --git a/dace/data.py b/dace/data.py new file mode 100644 index 0000000000..2c0e6a47ac --- /dev/null +++ b/dace/data.py @@ -0,0 +1,496 @@ +import functools
+import operator
+import re
+import copy as cp
+import sympy as sp
+
+import dace
+from dace.codegen import cppunparse
+from dace import symbolic
+from dace.properties import (Property, make_properties, ReferenceProperty,
+ ShapeProperty, SubsetProperty, SymbolicProperty,
+ TypeClassProperty, DebugInfoProperty,
+ CodeProperty)
+
+
+def validate_name(name):
+ if not isinstance(name, str):
+ return False
+ if re.match(r'^[a-zA-Z_][a-zA-Z_0-9]*$', name) is None:
+ return False
+ return True
+
+
+@make_properties
+class Data(object):
+ """ Data type descriptors that can be used as references to memory.
+ Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
+ """
+
+ dtype = TypeClassProperty()
+ shape = ShapeProperty()
+ transient = Property(dtype=bool)
+ storage = Property(
+ dtype=dace.types.StorageType,
+ desc="Storage location",
+ enum=dace.types.StorageType,
+ default=dace.types.StorageType.Default,
+ from_string=lambda x: dace.types.StorageType[x])
+ location = Property(
+ dtype=str, # Dict[str, symbolic]
+ desc='Full storage location identifier (e.g., rank, GPU ID)',
+ default='')
+ toplevel = Property(
+ dtype=bool, desc="Allocate array outside of state", default=False)
+ debuginfo = DebugInfoProperty()
+
+ def __init__(self, dtype, shape, transient, storage, location, toplevel,
+ debuginfo):
+ self.dtype = dtype
+ self.shape = shape
+ self.transient = transient
+ self.storage = storage
+ self.location = location
+ self.toplevel = toplevel
+ self.debuginfo = debuginfo
+ self._validate()
+
+ def validate(self):
+ """ Validate the correctness of this object.
+ Raises an exception on error. """
+ self._validate()
+
+ # Validation of this class is in a separate function, so that this
+ # class can call `_validate()` without calling the subclasses'
+ # `validate` function.
+ def _validate(self):
+ if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol,
+ symbolic.sympy.Basic)) for s in self.shape):
+ raise TypeError('Shape must be a list or tuple of integer values '
+ 'or symbols')
+ return True
+
+ def copy(self):
+ raise RuntimeError(
+ 'Data descriptors are unique and should not be copied')
+
+ def is_equivalent(self, other):
+ """ Check for equivalence (shape and type) of two data descriptors.
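+ Informally: descriptors compare equal when their dtypes and
+ dimensions match; storage location and transience are not compared.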
""" + raise NotImplementedError + + def signature(self, with_types=True, name=None): + """Returns a string for a C++ function signature (e.g., `int *A`). """ + raise NotImplementedError + + def __repr__(self): + return 'Abstract Data Container, DO NOT USE' + + +@make_properties +class Scalar(Data): + """ Data descriptor of a scalar value. """ + + allow_conflicts = Property(dtype=bool) + + def __init__(self, + dtype, + transient=False, + storage=dace.types.StorageType.Default, + allow_conflicts=False, + location='', + toplevel=False, + debuginfo=None): + self.allow_conflicts = allow_conflicts + shape = [1] + super(Scalar, self).__init__(dtype, shape, transient, storage, + location, toplevel, debuginfo) + + def __repr__(self): + return 'Scalar (dtype=%s)' % self.dtype + + def clone(self): + return Scalar(self.dtype, self.transient, self.storage, + self.allow_conflicts, self.location, self.toplevel, + self.debuginfo) + + @property + def strides(self): + return self.shape + + @property + def offset(self): + return [0] + + def is_equivalent(self, other): + if not isinstance(other, Scalar): + return False + if self.dtype != other.type: + return False + return True + + def signature(self, with_types=True, name=None): + if not with_types: return name + return str(self.dtype.ctype) + ' ' + name + + def sizes(self): + return None + + def covers_range(self, rng): + if len(rng) != 1: + return False + + rng = rng[0] + + try: + if (rng[1] - rng[0]) > rng[2]: + return False + except TypeError: # cannot determine truth value of Relational + pass + #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % ((rng[1] - rng[0]) > rng[2]), + # 'If this expression is false, please refine symbol definitions in the program.') + + return True + + +def set_materialize_func(obj, val): + """ Change the storage type of an array with a materialize function to + immaterial. + """ + if val is not None: + if (obj.storage != dace.types.StorageType.Default + and obj.storage != dace.types.StorageType.Immaterial): + raise ValueError("Immaterial array must have immaterial storage, " + "but has: {}".format(storage)) + obj.storage = dace.types.StorageType.Immaterial + obj._materialize_func = val + + +@make_properties +class Array(Data): + """ Array/constant descriptor (dimensions, type and other properties). """ + + # Properties + allow_conflicts = Property(dtype=bool) + # TODO: Should we use a Code property here? 
+
+ materialize_func = Property(
+ dtype=str, allow_none=True, setter=set_materialize_func)
+ access_order = Property(dtype=tuple)
+ strides = Property(dtype=list)
+ offset = Property(dtype=list)
+ may_alias = Property(
+ dtype=bool,
+ default=False,
+ desc='This pointer may alias with other pointers in '
+ 'the same function')
+
+ def __init__(self,
+ dtype,
+ shape,
+ materialize_func=None,
+ transient=False,
+ allow_conflicts=False,
+ storage=dace.types.StorageType.Default,
+ location='',
+ access_order=None,
+ strides=None,
+ offset=None,
+ may_alias=False,
+ toplevel=False,
+ debuginfo=None):
+
+ if shape is None:
+ raise IndexError('Shape must not be None')
+
+ super(Array, self).__init__(dtype, shape, transient, storage, location,
+ toplevel, debuginfo)
+
+ self.allow_conflicts = allow_conflicts
+ self.materialize_func = materialize_func
+ self.may_alias = may_alias
+
+ if access_order is not None:
+ self.access_order = cp.copy(access_order)
+ else:
+ self.access_order = tuple(i for i in range(len(shape)))
+
+ if strides is not None:
+ self.strides = cp.copy(strides)
+ else:
+ self.strides = cp.copy(list(shape))
+
+ if offset is not None:
+ self.offset = cp.copy(offset)
+ else:
+ self.offset = [0] * len(shape)
+
+ self.validate()
+
+ def __repr__(self):
+ return 'Array (dtype=%s, shape=%s)' % (self.dtype, self.shape)
+
+ def clone(self):
+ return Array(self.dtype, self.shape, self.materialize_func,
+ self.transient, self.allow_conflicts, self.storage,
+ self.location, self.access_order, self.strides,
+ self.offset, self.may_alias, self.toplevel,
+ self.debuginfo)
+
+ def validate(self):
+ super(Array, self).validate()
+ if len(self.access_order) != len(self.shape):
+ raise TypeError('Access order must be the same size as shape')
+
+ if len(self.strides) != len(self.shape):
+ raise TypeError('Strides must be the same size as shape')
+
+ if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol,
+ symbolic.sympy.Basic))
+ for s in self.strides):
+ raise TypeError('Strides must be a list or tuple of integer '
+ 'values or symbols')
+
+ if len(self.offset) != len(self.shape):
+ raise TypeError('Offset must be the same size as shape')
+
+ def covers_range(self, rng):
+ if len(rng) != len(self.shape):
+ return False
+
+ for s, (rb, re, rs) in zip(self.shape, rng):
+ # Shape has to be positive
+ if isinstance(s, sp.Basic):
+ olds = s
+ if 'positive' in s.assumptions0:
+ s = sp.Symbol(str(s), **s.assumptions0)
+ else:
+ s = sp.Symbol(str(s), positive=True, **s.assumptions0)
+ if isinstance(rb, sp.Basic):
+ rb = rb.subs({olds: s})
+ if isinstance(re, sp.Basic):
+ re = re.subs({olds: s})
+ if isinstance(rs, sp.Basic):
+ rs = rs.subs({olds: s})
+
+ try:
+ if rb < 0: # Negative offset
+ return False
+ except TypeError: # cannot determine truth value of Relational
+ pass
+ #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0),
+ # 'If this expression is false, please refine symbol definitions in the program.')
+ try:
+ if re > s: # Beyond shape
+ return False
+ except TypeError: # cannot determine truth value of Relational
+ pass
+ #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s),
+ # 'If this expression is false, please refine symbol definitions in the program.')
+
+ return True
+
+ # Checks for equivalent shape and type
+ def is_equivalent(self, other):
+ if not isinstance(other, Array):
+ return False
+
+ # Test type
+ if self.dtype != other.dtype:
+ return False
+
+ # Test dimensionality
+ if len(self.shape) != len(other.shape):
+ return False
+
+ # Test shape
+ for dim, otherdim in zip(self.shape, other.shape):
+ # If both are symbols, ensure equality
+ if symbolic.issymbolic(dim) and symbolic.issymbolic(otherdim):
+ if dim != otherdim:
+ return False
+
+ # If one is a symbol and the other is a constant
+ # make sure they are equivalent
+ elif symbolic.issymbolic(otherdim):
+ if symbolic.eval(otherdim) != dim:
+ return False
+ elif symbolic.issymbolic(dim):
+ if symbolic.eval(dim) != otherdim:
+ return False
+ else:
+ # Any other case (constant vs. constant), check for equality
+ if otherdim != dim:
+ return False
+ return True
+
+ def signature(self, with_types=True, name=None):
+ arrname = name
+ if self.materialize_func is not None:
+ arrname = '/* ' + arrname + ' (immaterial) */'
+ if not with_types:
+ return 'nullptr'
+
+ if not with_types:
+ return arrname
+ if self.may_alias:
+ return str(self.dtype.ctype) + ' *' + arrname
+ return str(self.dtype.ctype) + ' * __restrict__ ' + arrname
+
+ def sizes(self):
+ return [
+ d.name if isinstance(d, symbolic.symbol) else str(d)
+ for d in self.shape
+ ]
+
+
+@make_properties
+class Stream(Data):
+ """ Stream (or stream array) data descriptor. """
+
+ # Properties
+ strides = Property(dtype=list)
+ offset = Property(dtype=list)
+ buffer_size = Property(dtype=int, desc="Size of internal buffer.")
+ veclen = Property(
+ dtype=int, desc="Vector length. Memlets must adhere to this.")
+
+ def __init__(self,
+ dtype,
+ veclen,
+ buffer_size,
+ shape=None,
+ transient=False,
+ storage=dace.types.StorageType.Default,
+ location='',
+ strides=None,
+ offset=None,
+ toplevel=False,
+ debuginfo=None):
+
+ if shape is None:
+ shape = (1, )
+
+ self.veclen = veclen
+ self.buffer_size = buffer_size
+
+ if strides is not None:
+ if len(strides) != len(shape):
+ raise TypeError('Strides must be the same size as shape')
+ self.strides = cp.copy(strides)
+ else:
+ self.strides = cp.copy(list(shape))
+
+ if offset is not None:
+ if len(offset) != len(shape):
+ raise TypeError('Offset must be the same size as shape')
+ self.offset = cp.copy(offset)
+ else:
+ self.offset = [0] * len(shape)
+
+ super(Stream, self).__init__(dtype, shape, transient, storage,
+ location, toplevel, debuginfo)
+
+ def __repr__(self):
+ return 'Stream (dtype=%s, shape=%s)' % (self.dtype, self.shape)
+
+ def clone(self):
+ return Stream(self.dtype, self.veclen, self.buffer_size, self.shape,
+ self.transient, self.storage, self.location,
+ self.strides, self.offset, self.toplevel, self.debuginfo)
+
+ # Checks for equivalent shape and type
+ def is_equivalent(self, other):
+ if not isinstance(other, Stream):
+ return False
+
+ # Test type
+ if self.dtype != other.dtype:
+ return False
+
+ # Test dimensionality
+ if len(self.shape) != len(other.shape):
+ return False
+
+ # Test shape
+ for dim, otherdim in zip(self.shape, other.shape):
+ # If both are symbols, ensure equality
+ if symbolic.issymbolic(dim) and symbolic.issymbolic(otherdim):
+ if dim != otherdim:
+ return False
+
+ # If one is a symbol and the other is a constant
+ # make sure they are equivalent
+ elif symbolic.issymbolic(otherdim):
+ if symbolic.eval(otherdim) != dim:
+ return False
+ elif symbolic.issymbolic(dim):
+ if symbolic.eval(dim) != otherdim:
+ return False
+ else:
+ # Any other case (constant vs. constant), check for equality
+ if otherdim != dim:
+ return False
+ return True
+
+ def signature(self, with_types=True, name=None):
+ if not with_types: return name
+ if self.storage in [
+ dace.types.StorageType.GPU_Global,
+ dace.types.StorageType.GPU_Shared,
+ dace.types.StorageType.GPU_Stack
+ ]:
+ return 'dace::GPUStream<%s, %s> %s' % (
+ str(self.dtype.ctype), 'true'
+ if sp.log(self.buffer_size, 2).is_Integer else 'false', name)
+
+ return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name)
+
+ def sizes(self):
+ return [
+ d.name if isinstance(d, symbolic.symbol) else str(d)
+ for d in self.shape
+ ]
+
+ def size_string(self):
+ return (" * ".join([
+ cppunparse.pyexpr2cpp(dace.symbolic.symstr(s))
+ for s in self.strides
+ ]))
+
+ def is_stream_array(self):
+ return functools.reduce(lambda a, b: a * b, self.strides) != 1
+
+ def covers_range(self, rng):
+ if len(rng) != len(self.shape):
+ return False
+
+ for s, (rb, re, rs) in zip(self.shape, rng):
+ # Shape has to be positive
+ if isinstance(s, sp.Basic):
+ olds = s
+ if 'positive' in s.assumptions0:
+ s = sp.Symbol(str(s), **s.assumptions0)
+ else:
+ s = sp.Symbol(str(s), positive=True, **s.assumptions0)
+ if isinstance(rb, sp.Basic):
+ rb = rb.subs({olds: s})
+ if isinstance(re, sp.Basic):
+ re = re.subs({olds: s})
+ if isinstance(rs, sp.Basic):
+ rs = rs.subs({olds: s})
+
+ try:
+ if rb < 0: # Negative offset
+ return False
+ except TypeError: # cannot determine truth value of Relational
+ pass
+ #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0),
+ # 'If this expression is false, please refine symbol definitions in the program.')
+ try:
+ if re > s: # Beyond shape
+ return False
+ except TypeError: # cannot determine truth value of Relational
+ pass
+ #print('WARNING: Cannot evaluate relational expression %s, assuming true.'
% (re < s), + # 'If this expression is false, please refine symbol definitions in the program.') + + return True diff --git a/dace/external/cub b/dace/external/cub new file mode 160000 index 0000000000..c3cceac115 --- /dev/null +++ b/dace/external/cub @@ -0,0 +1 @@ +Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304 diff --git a/dace/external/hlslib b/dace/external/hlslib new file mode 160000 index 0000000000..628cd40a4a --- /dev/null +++ b/dace/external/hlslib @@ -0,0 +1 @@ +Subproject commit 628cd40a4ac5fe5dd2799030398fcb7a8072252c diff --git a/dace/external/moodycamel b/dace/external/moodycamel new file mode 160000 index 0000000000..dea078cf5b --- /dev/null +++ b/dace/external/moodycamel @@ -0,0 +1 @@ +Subproject commit dea078cf5b6e742cd67a0d725e36f872feca4de4 diff --git a/dace/frontend/__init__.py b/dace/frontend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dace/frontend/common/__init__.py b/dace/frontend/common/__init__.py new file mode 100644 index 0000000000..343f8cadd9 --- /dev/null +++ b/dace/frontend/common/__init__.py @@ -0,0 +1,5 @@ +from .op_impl import matrix_multiplication, matrix_multiplication_s +from .op_impl import scalar_array_multiplication, scalar_array_multiplication_s +from .op_impl import constant_array_multiplication +from .op_impl import matrix_transpose, matrix_transpose_s +from .op_impl import matrix_pointwise_op diff --git a/dace/frontend/common/op_impl.py b/dace/frontend/common/op_impl.py new file mode 100644 index 0000000000..bf7e617470 --- /dev/null +++ b/dace/frontend/common/op_impl.py @@ -0,0 +1,1731 @@ +''' DaCe SDFG linear algebra operation library. ''' + +import copy +import dace +import dace.sdfg as sd +import dace.subsets as sbs +from dace import symbolic +import typing + +State = dace.sdfg.SDFGState +Shape = typing.List[typing.Union[int, dace.symbol]] +Index = typing.List[typing.Union[int, str, dace.symbol]] +Node = dace.graph.nodes.Node +DNode = dace.graph.nodes.AccessNode + +# TODO: Most of the external operations here emit Z (complex double) ops, fix + + +# TODO: Refactor to use GPUTransformLocalStorage? +def gpu_transform_tasklet(sdfg, graph, tasklet_node): + """ Transforms a tasklet to run on the GPU. Adapted from + `GPUTransformLocalStorage`. 
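+ Informally: every input/output array of the tasklet that is not
+ already in GPU memory is cloned into a transient GPU_Global array
+ (named ``gpu_<array>``), and the tasklet's memlets are rerouted
+ through these clones.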
+ @see: dace.transformation.dataflow.GPUTransformLocalStorage + """ + cnode = tasklet_node + exit_nodes = [tasklet_node] + + gpu_storage_types = [ + dace.types.StorageType.GPU_Global, dace.types.StorageType.GPU_Shared, + dace.types.StorageType.GPU_Stack + ] + + ####################################################### + # Add GPU copies of CPU arrays (i.e., not already on GPU) + + # First, understand which arrays to clone + all_out_edges = [] + for enode in exit_nodes: + all_out_edges.extend(list(graph.out_edges(enode))) + in_arrays_to_clone = set() + out_arrays_to_clone = set() + for e in graph.in_edges(cnode): + data_node = sd.find_input_arraynode(graph, e) + if data_node.desc(sdfg).storage not in gpu_storage_types: + in_arrays_to_clone.add((data_node, e.data)) + for e in all_out_edges: + data_node = sd.find_output_arraynode(graph, e) + if data_node.desc(sdfg).storage not in gpu_storage_types: + out_arrays_to_clone.add((data_node, e.data)) + + # Second, create a GPU clone of each array + # TODO: Overapproximate union of memlets + cloned_arrays = {} + in_cloned_arraynodes = {} + out_cloned_arraynodes = {} + for array_node, memlet in in_arrays_to_clone: + array = array_node.desc(sdfg) + cloned_name = 'gpu_' + array_node.data + for i, r in enumerate(memlet.bounding_box_size()): + size = symbolic.overapproximate(r) + try: + if int(size) == 1: + suffix = [] + for c in str(memlet.subset[i][0]): + if c.isalpha() or c.isdigit() or c == '_': + suffix.append(c) + elif c == '+': + suffix.append('p') + elif c == '-': + suffix.append('m') + elif c == '*': + suffix.append('t') + elif c == '/': + suffix.append('d') + cloned_name += '_' + ''.join(suffix) + except: + continue + if cloned_name in sdfg.arrays.keys(): + cloned_array = sdfg.arrays[cloned_name] + elif array_node.data in cloned_arrays: + cloned_array = cloned_arrays[array_node.data] + else: + full_shape = [] + for r in memlet.bounding_box_size(): + size = symbolic.overapproximate(r) + try: + full_shape.append(int(size)) + except: + full_shape.append(size) + actual_dims = [ + idx for idx, r in enumerate(full_shape) + if not (isinstance(r, int) and r == 1) + ] + if len(actual_dims) == 0: # abort + actual_dims = [len(full_shape) - 1] + if isinstance(array, dace.data.Scalar): + cloned_array = sdfg.add_array( + name=cloned_name, + shape=[1], + dtype=array.dtype, + transient=True, + storage=dace.types.StorageType.GPU_Global) + else: + cloned_array = sdfg.add_array( + name=cloned_name, + shape=[full_shape[d] for d in actual_dims], + dtype=array.dtype, + materialize_func=array.materialize_func, + transient=True, + storage=dace.types.StorageType.GPU_Global, + allow_conflicts=array.allow_conflicts, + access_order=tuple( + [array.access_order[d] for d in actual_dims]), + strides=[array.strides[d] for d in actual_dims], + offset=[array.offset[d] for d in actual_dims]) + cloned_arrays[array_node.data] = cloned_name + cloned_node = type(array_node)(cloned_name) + + in_cloned_arraynodes[array_node.data] = cloned_node + for array_node, memlet in out_arrays_to_clone: + array = array_node.desc(sdfg) + cloned_name = 'gpu_' + array_node.data + for i, r in enumerate(memlet.bounding_box_size()): + size = symbolic.overapproximate(r) + try: + if int(size) == 1: + suffix = [] + for c in str(memlet.subset[i][0]): + if c.isalpha() or c.isdigit() or c == '_': + suffix.append(c) + elif c == '+': + suffix.append('p') + elif c == '-': + suffix.append('m') + elif c == '*': + suffix.append('t') + elif c == '/': + suffix.append('d') + cloned_name += '_' + ''.join(suffix) + 
except: + continue + if cloned_name in sdfg.arrays.keys(): + cloned_array = sdfg.arrays[cloned_name] + elif array_node.data in cloned_arrays: + cloned_array = cloned_arrays[array_node.data] + else: + full_shape = [] + for r in memlet.bounding_box_size(): + size = symbolic.overapproximate(r) + try: + full_shape.append(int(size)) + except: + full_shape.append(size) + actual_dims = [ + idx for idx, r in enumerate(full_shape) + if not (isinstance(r, int) and r == 1) + ] + if len(actual_dims) == 0: # abort + actual_dims = [len(full_shape) - 1] + if isinstance(array, dace.data.Scalar): + cloned_array = sdfg.add_array( + name=cloned_name, + shape=[1], + dtype=array.dtype, + transient=True, + storage=dace.types.StorageType.GPU_Global) + else: + cloned_array = sdfg.add_array( + name=cloned_name, + shape=[full_shape[d] for d in actual_dims], + dtype=array.dtype, + materialize_func=array.materialize_func, + transient=True, + storage=dace.types.StorageType.GPU_Global, + allow_conflicts=array.allow_conflicts, + access_order=tuple( + [array.access_order[d] for d in actual_dims]), + strides=[array.strides[d] for d in actual_dims], + offset=[array.offset[d] for d in actual_dims]) + cloned_arrays[array_node.data] = cloned_name + cloned_node = type(array_node)(cloned_name) + cloned_node.setzero = True + + out_cloned_arraynodes[array_node.data] = cloned_node + + # Third, connect the cloned arrays to the originals + for array_name, node in in_cloned_arraynodes.items(): + graph.add_node(node) + is_scalar = isinstance(sdfg.arrays[array_name], dace.data.Scalar) + for edge in graph.in_edges(cnode): + if edge.data.data == array_name: + graph.remove_edge(edge) + newmemlet = copy.deepcopy(edge.data) + newmemlet.data = node.data + + if is_scalar: + newmemlet.subset = sbs.Indices([0]) + else: + offset = [] + lost_dims = [] + lost_ranges = [] + newsubset = [None] * len(edge.data.subset) + for ind, r in enumerate(edge.data.subset): + offset.append(r[0]) + if isinstance(edge.data.subset[ind], tuple): + begin = edge.data.subset[ind][0] - r[0] + end = edge.data.subset[ind][1] - r[0] + step = edge.data.subset[ind][2] + if begin == end: + lost_dims.append(ind) + lost_ranges.append((begin, end, step)) + else: + newsubset[ind] = (begin, end, step) + else: + newsubset[ind] -= r[0] + if len(lost_dims) == len(edge.data.subset): + newmemlet.subset = type( + edge.data.subset)([lost_ranges[-1]]) + else: + newmemlet.subset = type(edge.data.subset)( + [r for r in newsubset if r is not None]) + + graph.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, + newmemlet) + + edge.data.other_subset = newmemlet.subset + graph.add_edge(edge.src, None, node, None, edge.data) + for array_name, node in out_cloned_arraynodes.items(): + graph.add_node(node) + is_scalar = isinstance(sdfg.arrays[array_name], dace.data.Scalar) + for edge in all_out_edges: + if edge.data.data == array_name: + graph.remove_edge(edge) + newmemlet = copy.deepcopy(edge.data) + newmemlet.data = node.data + + if is_scalar: + newmemlet.subset = sbs.Indices([0]) + else: + offset = [] + lost_dims = [] + lost_ranges = [] + newsubset = [None] * len(edge.data.subset) + for ind, r in enumerate(edge.data.subset): + offset.append(r[0]) + if isinstance(edge.data.subset[ind], tuple): + begin = edge.data.subset[ind][0] - r[0] + end = edge.data.subset[ind][1] - r[0] + step = edge.data.subset[ind][2] + if begin == end: + lost_dims.append(ind) + lost_ranges.append((begin, end, step)) + else: + newsubset[ind] = (begin, end, step) + else: + newsubset[ind] -= r[0] + if len(lost_dims) == 
len(edge.data.subset): + newmemlet.subset = type( + edge.data.subset)([lost_ranges[-1]]) + else: + newmemlet.subset = type(edge.data.subset)( + [r for r in newsubset if r is not None]) + + graph.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, + newmemlet) + + edge.data.data = node.data + edge.data.other_subset = edge.data.subset + edge.data.subset = newmemlet.subset + graph.add_edge(node, None, edge.dst, None, edge.data) + + +class ValidationError(Exception): + """ An exception raised when inputs are not validated in SDFG library + calls. """ + + def __init__(self, message): + super().__init__(message) + + +def validate_matrix_multiplication( + A_shape: Shape, + B_shape: Shape, + C_shape: Shape, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None +) -> ((str, str, str), (str, str, str), (str, str, str), (str, str, str)): + """ Validates a matrix multiplication operation, based on the shapes and + indices of the arrays involved. Returns the ranges of the maps and + memlets at all levels as strings. + """ + + # Validate input + if len(A_shape) < 2: + raise ValidationError( + 'Array A has less than 2 dimensions: {}'.format(A_shape)) + A_mm_shape = A_shape[-2:] + if len(B_shape) < 2: + raise ValidationError( + 'Array B has less than 2 dimensions: {}'.format(B_shape)) + B_mm_shape = B_shape[-2:] + if A_mm_shape[-1] != B_mm_shape[0]: + raise ValidationError( + 'N-dimension mismatch between arrays A and B: {} != {}'.format( + A_mm_shape[-1], B_mm_shape[0])) + + # Dimension sizes and ranges + M = A_mm_shape[0] + N = A_mm_shape[-1] + K = B_mm_shape[-1] + M_range = '0:{}'.format(M) + N_range = '0:{}'.format(N) + K_range = '0:{}'.format(K) + + # Validate slices and set input array access ranges + A_outer_range = '{}, {}'.format(M_range, N_range) + A_middle_range = '{}, ik'.format(M_range) + A_inner_range = 'ii, ik' + if len(A_shape) > 2: + if A_index is None or len(A_index) != len(A_shape) - 2: + raise ValidationError( + 'Invalid slice {} for array A with dimensions {}'.format( + A_index, A_shape)) + A_index = [str(idx) for idx in A_index] + A_outer_range = '{}, {}'.format(', '.join(A_index), A_outer_range) + A_middle_range = '{}, {}'.format(', '.join(A_index), A_middle_range) + A_inner_range = '{}, {}'.format(', '.join(A_index), A_inner_range) + B_outer_range = '{}, {}'.format(N_range, K_range) + B_middle_range = 'ik, {}'.format(K_range) + B_inner_range = 'ik, ij' + if len(B_shape) > 2: + if B_index is None or len(B_index) != len(B_shape) - 2: + raise ValidationError( + 'Invalid slice {} for array B with dimensions {}'.format( + B_index, B_shape)) + B_index = [str(idx) for idx in B_index] + B_outer_range = '{}, {}'.format(', '.join(B_index), B_outer_range) + B_middle_range = '{}, {}'.format(', '.join(B_index), B_middle_range) + B_inner_range = '{}, {}'.format(', '.join(B_index), B_inner_range) + + # Validate output + C_mm_shape = [M, K] + if len(C_shape) < 2: + raise ValidationError( + 'Array C has less than 2 dimensions: {}'.format(C_shape)) + if list(C_shape[-2:]) != C_mm_shape: + raise ValidationError( + 'Shape mismatch in array C: expected {}, but got {}'.format( + C_mm_shape, C_shape[-2:])) + C_outer_range = '{}, {}'.format(M_range, K_range) + C_middle_range = '{}, {}'.format(M_range, K_range) + C_inner_range = 'ii, ij' + if len(C_shape) > 2: + if C_index is None or len(C_index) != len(C_shape) - 2: + raise ValidationError( + 'Invalid slice {} for array C with dimensions {}'.format( + C_index, C_shape)) + C_index = [str(idx) for idx in C_index] + C_outer_range = 
'{}, {}'.format(', '.join(C_index), C_outer_range) + C_middle_range = '{}, {}'.format(', '.join(C_index), C_middle_range) + C_inner_range = '{}, {}'.format(', '.join(C_index), C_inner_range) + + return ((M_range, N_range, K_range), (A_outer_range, A_middle_range, + A_inner_range), + (B_outer_range, B_middle_range, + B_inner_range), (C_outer_range, C_middle_range, C_inner_range)) + + +def matrix_multiplication(state: State, + A_src: Node, + A_node: DNode, + B_src: Node, + B_node: DNode, + C_dst: Node, + C_node: DNode, + accumulate: bool = False, + interchange: bool = True, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + label: str = None): + """ Adds a matrix multiplication operation to an existing SDFG state. + @param A_src: The source node from which the memlet of matrix A is + connected. + @param A_node: The Access Node for matrix A. + @param B_src: The source node from which the memlet of matrix B is + connected. + @param B_node: The Access Node for matrix B. + @param C_dst: The destination node to which the memlet of matrix C is + connected. + @param C_node: The Access Node for matrix C. + @param accumulate: Whether to accumulate to C or store to it. + @param interchange: If True, interchanges the multiplication maps for + performance (in some cases). + @param A_index: Slice of matrix A to use for multiplication. + @param B_index: Slice of matrix B to use for multiplication. + @param C_index: Slice of matrix C to use for multiplication. + @param label: Optional label for the maps and tasklet. + """ + + # Validate input + sdfg = state.parent + map_ranges, A_ranges, B_ranges, C_ranges = validate_matrix_multiplication( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, + C_node.desc(sdfg).shape, A_index, B_index, C_index) + + # Extract ranges + M_range, N_range, K_range = map_ranges + A_outer_range, A_middle_range, A_inner_range = A_ranges + B_outer_range, B_middle_range, B_inner_range = B_ranges + C_outer_range, C_middle_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = state.label + + # Create maps/tasklet + k_entry, k_exit = state.add_map( + name=label + '_' + 'k_map', + ndrange=dict(ik=N_range), + schedule=dace.types.ScheduleType.Sequential) + k_entry.in_connectors = {'IN_1', 'IN_2'} + k_entry.out_connectors = {'OUT_1', 'OUT_2'} + k_exit.in_connectors = {'IN_1'} + k_exit.out_connectors = {'OUT_1'} + ij_entry, ij_exit = state.add_map( + name=label + '_' + 'ij_map', ndrange=dict(ii=M_range, ij=K_range)) + tasklet = state.add_tasklet( + name=label + '_' + 'tasklet', + inputs={'a', 'b'}, + outputs={'c'}, + code='c = a * b') + ij_entry.in_connectors = {'IN_1', 'IN_2'} + ij_entry.out_connectors = {'OUT_1', 'OUT_2'} + ij_exit.in_connectors = {'IN_1'} + ij_exit.out_connectors = {'OUT_1'} + + # Add edges + if interchange: + state.add_edge(A_src, None, k_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_src, None, k_entry, 'IN_2', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(k_entry, 'OUT_1', ij_entry, 'IN_1', + dace.Memlet.simple(A_node, A_middle_range)) + state.add_edge(k_entry, 'OUT_2', ij_entry, 'IN_2', + dace.Memlet.simple(B_node, B_middle_range)) + state.add_edge(ij_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + state.add_edge(ij_entry, 'OUT_2', tasklet, 'b', + dace.Memlet.simple(B_node, B_inner_range)) + wcr = 0 + if accumulate: + wcr = None + state.add_edge( + tasklet, 'c', ij_exit, 'IN_1', + dace.Memlet.simple( + C_node, + C_inner_range, + 
wcr_str='lambda x, y: x + y', + wcr_identity=wcr, + wcr_conflict=False)) + state.add_edge(ij_exit, 'OUT_1', k_exit, 'IN_1', + dace.Memlet.simple(C_node, C_middle_range)) + state.add_edge(k_exit, 'OUT_1', C_dst, None, + dace.Memlet.simple(C_node, C_outer_range)) + else: + state.add_edge(A_src, None, ij_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_src, None, ij_entry, 'IN_2', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(ij_entry, 'OUT_1', k_entry, 'IN_1', + dace.Memlet.simple(A_node, A_middle_range)) + state.add_edge(ij_entry, 'OUT_2', k_entry, 'IN_2', + dace.Memlet.simple(B_node, B_middle_range)) + state.add_edge(k_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + state.add_edge(k_entry, 'OUT_2', tasklet, 'b', + dace.Memlet.simple(B_node, B_inner_range)) + wcr = 0 + if accumulate: + wcr = None + state.add_edge( + tasklet, 'c', k_exit, 'IN_1', + dace.Memlet.simple( + C_node, + C_inner_range, + wcr_str='lambda x, y: x + y', + wcr_identity=wcr, + wcr_conflict=False)) + state.add_edge(k_exit, 'OUT_1', ij_exit, 'IN_1', + dace.Memlet.simple(C_node, C_middle_range)) + state.add_edge(ij_exit, 'OUT_1', C_dst, None, + dace.Memlet.simple(C_node, C_outer_range)) + + +def matrix_multiplication_cublas(state: State, + A_src: Node, + A_node: DNode, + B_src: Node, + B_node: DNode, + C_dst: Node, + C_node: DNode, + accumulate: bool = False, + interchange: bool = True, + alpha: str = 'const_pone', + beta: str = 'const_zero', + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + label: str = None): + """ Adds a matrix multiplication operation to an existing SDFG state, + using CUBLAS as the implementation. + @param A_src: The source node from which the memlet of matrix A is + connected. + @param A_node: The Access Node for matrix A. + @param B_src: The source node from which the memlet of matrix B is + connected. + @param B_node: The Access Node for matrix B. + @param C_dst: The destination node to which the memlet of matrix C is + connected. + @param C_node: The Access Node for matrix C. + @param accumulate: Whether to accumulate to C or store to it. + @param interchange: If True, interchanges the multiplication maps for + performance (in some cases). + @param alpha: Alpha value for GEMM. + @param beta: Beta value for GEMM. + @param A_index: Slice of matrix A to use for multiplication. + @param B_index: Slice of matrix B to use for multiplication. + @param C_index: Slice of matrix C to use for multiplication. + @param label: Optional label for the maps and tasklet. 
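+ Note (informal): the generated tasklet code assumes that a cuBLAS
+ handle named ``handle``, a size symbol ``bsize`` and the constants
+ ``const_pone``/``const_zero`` are visible in the generated code's
+ scope.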
+ """ + + # Validate input + sdfg = state.parent + map_ranges, A_ranges, B_ranges, C_ranges = validate_matrix_multiplication( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, + C_node.desc(sdfg).shape, A_index, B_index, C_index) + + # Extract ranges + M_range, N_range, K_range = map_ranges + A_outer_range, A_middle_range, A_inner_range = A_ranges + B_outer_range, B_middle_range, B_inner_range = B_ranges + C_outer_range, C_middle_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = state.label + + # Create tasklet + tasklet = state.add_tasklet( + name=label + '_' + 'tasklet', + inputs={'a', 'b'}, + outputs={'c'}, + code=''' + //cuDoubleComplex alpha = make_cuDoubleComplex(1, 0); + //cuDoubleComplex beta = make_cuDoubleComplex(0, 0); + cublasSetStream(handle, __dace_current_stream); + cublasStatus_t status = cublasZgemm( + handle, + CUBLAS_OP_N, CUBLAS_OP_N, + bsize, bsize, bsize, + const_pone, + (cuDoubleComplex*)b, bsize, + (cuDoubleComplex*)a, bsize, + const_zero, + (cuDoubleComplex*)c, bsize + ); + ''', # cuBLAS is column-major, so we switch the arguments + language=dace.types.Language.CPP) + + state.add_edge(A_src, None, tasklet, 'a', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_src, None, tasklet, 'b', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(tasklet, 'c', C_dst, None, + dace.Memlet.simple(C_node, C_outer_range)) + + gpu_transform_tasklet(sdfg, state, tasklet) + + +def matrix_multiplication_cublas_v2(state: State, + A_src: Node, + A_node: DNode, + B_src: Node, + B_node: DNode, + C_src: Node, + C_src_node: DNode, + C_dst: Node, + C_dst_node: DNode, + accumulate: bool = False, + interchange: bool = True, + alpha: str = 'const_pone', + beta: str = 'const_zero', + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + label: str = None): + """ Adds a matrix multiplication operation to an existing SDFG state, + using CUBLAS as the implementation, and providing a separate source + and destination nodes for the output matrix. + @param A_src: The source node from which the memlet of matrix A is + connected. + @param A_node: The Access Node for matrix A. + @param B_src: The source node from which the memlet of matrix B is + connected. + @param B_node: The Access Node for matrix B. + @param C_src: The node from which the memlet of matrix C is + connected into the multiplication. + @param C_src_node: The input Access Node for matrix C. + @param C_dst: The node to which the memlet of matrix C is + connected out of the multiplication. + @param C_dst_node: The output Access Node for matrix C. + @param accumulate: Whether to accumulate to C or store to it. + @param interchange: If True, interchanges the multiplication maps for + performance (in some cases). + @param alpha: Alpha value for GEMM. + @param beta: Beta value for GEMM. + @param A_index: Slice of matrix A to use for multiplication. + @param B_index: Slice of matrix B to use for multiplication. + @param C_index: Slice of matrix C to use for multiplication. + @param label: Optional label for the maps and tasklet. 
+ """ + + # Validate input + sdfg = state.parent + map_ranges, A_ranges, B_ranges, C_ranges = validate_matrix_multiplication( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, + C_src_node.desc(sdfg).shape, A_index, B_index, C_index) + + # Extract ranges + M_range, N_range, K_range = map_ranges + A_outer_range, A_middle_range, A_inner_range = A_ranges + B_outer_range, B_middle_range, B_inner_range = B_ranges + C_outer_range, C_middle_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = state.label + + # Create tasklet + tasklet = state.add_tasklet( + name=label + '_' + 'tasklet', + inputs={'a', 'b', 'cin'}, + outputs={'c'}, + code=''' + //cuDoubleComplex alpha = make_cuDoubleComplex(1, 0); + //cuDoubleComplex beta = make_cuDoubleComplex(0, 0); + cublasSetStream(handle, __dace_current_stream); + cublasStatus_t status = cublasZgemm( + handle, + CUBLAS_OP_N, CUBLAS_OP_N, + bsize, bsize, bsize, + {alpha}, + (cuDoubleComplex*)b, bsize, + (cuDoubleComplex*)a, bsize, + {beta}, + (cuDoubleComplex*)c, bsize + ); + '''.format( + alpha=alpha, + beta=beta), # cuBLAS is column-major, so we switch the arguments + language=dace.types.Language.CPP) + + state.add_edge(A_src, None, tasklet, 'a', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_src, None, tasklet, 'b', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(C_src, None, tasklet, 'cin', + dace.Memlet.simple(C_src_node, C_outer_range)) + state.add_edge(tasklet, 'c', C_dst, None, + dace.Memlet.simple(C_dst_node, C_outer_range)) + + gpu_transform_tasklet(sdfg, state, tasklet) + + +def matrix_multiplication_mkl(state: State, + A_src: Node, + A_node: DNode, + B_src: Node, + B_node: DNode, + C_dst: Node, + C_node: DNode, + accumulate: bool = False, + interchange: bool = True, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + label: str = None): + """ Adds a matrix multiplication operation to an existing SDFG state, + using MKL as the implementation. + @param A_src: The source node from which the memlet of matrix A is + connected. + @param A_node: The Access Node for matrix A. + @param B_src: The source node from which the memlet of matrix B is + connected. + @param B_node: The Access Node for matrix B. + @param C_dst: The destination node to which the memlet of matrix C is + connected. + @param C_node: The Access Node for matrix C. + @param accumulate: Whether to accumulate to C or store to it. + @param interchange: If True, interchanges the multiplication maps for + performance (in some cases). + @param A_index: Slice of matrix A to use for multiplication. + @param B_index: Slice of matrix B to use for multiplication. + @param C_index: Slice of matrix C to use for multiplication. + @param label: Optional label for the maps and tasklet. 
+ """ + + # Validate input + sdfg = state.parent + map_ranges, A_ranges, B_ranges, C_ranges = validate_matrix_multiplication( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, + C_node.desc(sdfg).shape, A_index, B_index, C_index) + + # Extract ranges + M = A_node.desc(sdfg).shape[-2] + N = A_node.desc(sdfg).shape[-1] + K = B_node.desc(sdfg).shape[-1] + M_range, N_range, K_range = map_ranges + A_outer_range, A_middle_range, A_inner_range = A_ranges + B_outer_range, B_middle_range, B_inner_range = B_ranges + C_outer_range, C_middle_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = state.label + + # Create tasklet + tasklet = state.add_tasklet( + name=label + '_' + 'tasklet', + inputs={'a', 'b'}, + outputs={'c'}, + code=''' + std::complex alpha(1, 0); + std::complex beta(0, 0); + char opa = 'N'; + char opb = 'N'; + zgemm( + &opa, &opb, + &{m}, &{n}, &{k}, + (MKL_Complex16*)&alpha, + (MKL_Complex16*)a, &{m}, + (MKL_Complex16*)b, &{n}, + (MKL_Complex16*)&beta, + (MKL_Complex16*)c, &{m} + ); + '''.format(m=M, n=N, k=K), + language=dace.types.Language.CPP) + + state.add_edge(A_src, None, tasklet, 'a', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_src, None, tasklet, 'b', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(tasklet, 'c', C_dst, None, + dace.Memlet.simple(C_node, C_outer_range)) + + +def matrix_multiplication_s(A_label: str, + A_shape: Shape, + A_type: dace.types.typeclass, + B_label: str, + B_shape: Shape, + B_type: dace.types.typeclass, + create_C: bool = True, + C_label: str = None, + C_shape: Shape = None, + C_type: dace.types.typeclass = None, + is_A_transient: bool = False, + is_B_transient: bool = False, + is_C_transient: bool = False, + accumulate: bool = False, + interchange: bool = True, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + label: str = None) -> State: + """ Creates a new state with a matrix multiplication operation. 
""" + + # Set output attributes + if create_C: + if C_label is None: + C_label = A_label + B_label + if C_type is None: + C_type = A_type + C_shape = [A_shape[-2], B_shape[-1]] + else: + if C_shape is None: + raise ValidationError( + 'Array C is not transient, but its shape is not set') + + # Validate input + map_ranges, A_ranges, B_ranges, C_ranges = validate_matrix_multiplication( + A_shape, B_shape, C_shape, A_index, B_index, C_index) + + # Extract ranges + M_range, N_range, K_range = map_ranges + A_outer_range, A_middle_range, A_inner_range = A_ranges + B_outer_range, B_middle_range, B_inner_range = B_ranges + C_outer_range, C_middle_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = A_label + B_label + + # Create state + state = State(label=label) + + # Create data nodes + A_node = state.add_array( + A_label, A_shape, A_type, transient=is_A_transient) + B_node = state.add_array( + B_label, B_shape, B_type, transient=is_B_transient) + C_node = state.add_array( + C_label, C_shape, C_type, transient=is_C_transient or create_C) + + # Create maps/tasklet + k_entry, k_exit = state.add_map( + name=label + '_' + 'k_map', + ndrange=dict(ik=N_range), + schedule=dace.types.ScheduleType.Sequential) + k_entry.in_connectors = {'IN_1', 'IN_2'} + k_entry.out_connectors = {'OUT_1', 'OUT_2'} + k_exit.in_connectors = {'IN_1'} + k_exit.out_connectors = {'OUT_1'} + ij_entry, ij_exit = state.add_map( + name=label + '_' + 'ij_map', ndrange=dict(ii=M_range, ij=K_range)) + tasklet = state.add_tasklet( + name=label + '_' + 'tasklet', + inputs={'a', 'b'}, + outputs={'c'}, + code='c = a * b') + ij_entry.in_connectors = {'IN_1', 'IN_2'} + ij_entry.out_connectors = {'OUT_1', 'OUT_2'} + ij_exit.in_connectors = {'IN_1'} + ij_exit.out_connectors = {'OUT_1'} + + # Add edges + if interchange: + state.add_edge(A_node, None, k_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_node, None, k_entry, 'IN_2', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(k_entry, 'OUT_1', ij_entry, 'IN_1', + dace.Memlet.simple(A_node, A_middle_range)) + state.add_edge(k_entry, 'OUT_2', ij_entry, 'IN_2', + dace.Memlet.simple(B_node, B_middle_range)) + state.add_edge(ij_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + state.add_edge(ij_entry, 'OUT_2', tasklet, 'b', + dace.Memlet.simple(B_node, B_inner_range)) + wcr = 0 + if accumulate: + wcr = None + state.add_edge( + tasklet, 'c', ij_exit, 'IN_1', + dace.Memlet.simple( + C_node, + C_inner_range, + wcr_str='lambda x, y: x + y', + wcr_identity=wcr, + wcr_conflict=False)) + state.add_edge(ij_exit, 'OUT_1', k_exit, 'IN_1', + dace.Memlet.simple(C_node, C_middle_range)) + state.add_edge(k_exit, 'OUT_1', C_node, None, + dace.Memlet.simple(C_node, C_outer_range)) + else: + state.add_edge(A_node, None, ij_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_node, None, ij_entry, 'IN_2', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(ij_entry, 'OUT_1', k_entry, 'IN_1', + dace.Memlet.simple(A_node, A_middle_range)) + state.add_edge(ij_entry, 'OUT_2', k_entry, 'IN_2', + dace.Memlet.simple(B_node, B_middle_range)) + state.add_edge(k_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + state.add_edge(k_entry, 'OUT_2', tasklet, 'b', + dace.Memlet.simple(B_node, B_inner_range)) + wcr = 0 + if accumulate: + wcr = None + state.add_edge( + tasklet, 'c', k_exit, 'IN_1', + dace.Memlet.simple( + C_node, + C_inner_range, + wcr_str='lambda x, y: x + y', + 
wcr_identity=wcr,
+ wcr_conflict=False))
+ state.add_edge(k_exit, 'OUT_1', ij_exit, 'IN_1',
+ dace.Memlet.simple(C_node, C_middle_range))
+ state.add_edge(ij_exit, 'OUT_1', C_node, None,
+ dace.Memlet.simple(C_node, C_outer_range))
+
+ return state
+
+
+def validate_scalar_array_multiplication(
+ alpha_shape: Shape,
+ A_shape: Shape,
+ B_shape: Shape,
+ alpha_index: Index = None,
+ A_index: Index = None,
+ B_index: Index = None
+) -> (typing.Dict[str, str], (str, str), (str, str), (str, str)):
+ """ Validates a scalar-array multiplication operation, based on the shapes
+ and indices of the arrays involved. Returns the ranges of the maps and
+ memlets at all levels as strings. """
+
+ # Validate data
+ if alpha_shape != [1]:
+ if alpha_index is None or len(alpha_shape) != len(alpha_index):
+ raise ValidationError(
+ 'Slice of alpha is not a scalar: {}, {}'.format(
+ alpha_shape, alpha_index))
+ if A_index is not None:
+ true_A_shape = A_shape[len(A_index):]
+ else:
+ true_A_shape = A_shape
+ if B_index is not None:
+ true_B_shape = B_shape[len(B_index):]
+ else:
+ true_B_shape = B_shape
+ if true_A_shape != true_B_shape:
+ raise ValidationError('Dimension mismatch between arrays A and B: '
+ '{}({}) != {}({})'.format(
+ true_A_shape, A_shape, true_B_shape,
+ B_shape))
+
+ # Map ranges
+ map_ranges = dict()
+ for i, dim in enumerate(true_A_shape):
+ map_ranges['i{}'.format(i)] = '0:{}'.format(dim)
+
+ # Memlet ranges
+ alpha_outer_range = '0'
+ alpha_inner_range = '0'
+ if alpha_index is not None:
+ alpha_index = [str(idx) for idx in alpha_index]
+ alpha_outer_range = ', '.join(alpha_index)
+ alpha_inner_range = ', '.join(alpha_index)
+ A_outer_range = ', '.join(map_ranges.values())
+ A_inner_range = ', '.join(map_ranges.keys())
+ if A_index is not None:
+ A_index = [str(idx) for idx in A_index]
+ A_outer_range = '{}, {}'.format(', '.join(A_index), A_outer_range)
+ A_inner_range = '{}, {}'.format(', '.join(A_index), A_inner_range)
+ B_outer_range = ', '.join(map_ranges.values())
+ B_inner_range = ', '.join(map_ranges.keys())
+ if B_index is not None:
+ B_index = [str(idx) for idx in B_index]
+ B_outer_range = '{}, {}'.format(', '.join(B_index), B_outer_range)
+ B_inner_range = '{}, {}'.format(', '.join(B_index), B_inner_range)
+
+ return (map_ranges, (alpha_outer_range, alpha_inner_range),
+ (A_outer_range, A_inner_range), (B_outer_range, B_inner_range))
+
+
+def scalar_array_multiplication(state: State,
+ alpha_src: Node,
+ alpha_node: DNode,
+ A_src: Node,
+ A_node: DNode,
+ B_dst: Node,
+ B_node: DNode,
+ accumulate: bool = False,
+ wcr_conflict: bool = False,
+ alpha_index: Index = None,
+ A_index: Index = None,
+ B_index: Index = None,
+ label: str = None):
+ """ Adds a scalar-array multiplication operation to an existing state.
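+ Example (a sketch; ``alpha``, ``A`` and ``B`` are hypothetical access
+ nodes already present in ``state``)::
+
+ scalar_array_multiplication(state, alpha, alpha, A, A, B, B)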
""" + + # Validate data + sdfg = state.parent + alpha_shape = [1] + if hasattr(alpha_node, 'shape'): + alpha_shape = alpha_node.shape + ranges = validate_scalar_array_multiplication( + alpha_shape, + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, alpha_index, A_index, B_index) + map_ranges, alpha_ranges, A_ranges, B_ranges = ranges + alpha_outer_range, alpha_inner_range = alpha_ranges + A_outer_range, A_inner_range = A_ranges + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + + # Set label + if label is None: + label = state.label + + # Create map/tasklet + map_entry, map_exit = state.add_map( + name=label + '_map', ndrange=map_ranges) + map_entry.in_connectors = {'IN_1', 'IN_2'} + map_entry.out_connectors = {'OUT_1', 'OUT_2'} + map_exit.in_connectors = {'IN_1'} + map_exit.out_connectors = {'OUT_1'} + tasklet = state.add_tasklet( + name=label + '_tasklet', + inputs={'scalar', 'a'}, + outputs={'b'}, + code='b = scalar * a') + + # Add edges + state.add_edge(alpha_src, None, map_entry, 'IN_1', + dace.Memlet.simple(alpha_node, alpha_outer_range)) + state.add_edge(A_src, None, map_entry, 'IN_2', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(map_exit, 'OUT_1', B_dst, None, + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(map_entry, 'OUT_1', tasklet, 'scalar', + dace.Memlet.simple(alpha_node, alpha_inner_range)) + state.add_edge(map_entry, 'OUT_2', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + if accumulate: + state.add_edge( + tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple( + B_node, + B_inner_range, + wcr_str='lambda x, y: x + y', + wcr_identity=None, + wcr_conflict=wcr_conflict)) + else: + state.add_edge(tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple(B_node, B_inner_range)) + + +def scalar_array_multiplication_s(alpha_label: str, + alpha_shape: Shape, + alpha_type: dace.types.typeclass, + A_label: str, + A_shape: Shape, + A_type: dace.types.typeclass, + create_B: bool = True, + B_label: str = None, + B_shape: Shape = None, + B_type: dace.types.typeclass = None, + is_alpha_transient: bool = False, + is_A_transient: bool = False, + is_B_transient: bool = False, + accumulate: bool = False, + wcr_conflict: bool = False, + alpha_index: Index = None, + A_index: Index = None, + B_index: Index = None, + label: str = None) -> State: + """ Creates a new state with a scalar-array multiplication operation. 
""" + + # Set output attributes + if create_B: + if B_label is None: + B_label = alpha_label + A_label + if B_type is None: + B_type = A_type + B_shape = A_shape + else: + if B_shape is None: + raise ValidationError( + 'Array B is not transient, but its shape is not set') + + # Validate data + ranges = validate_scalar_array_multiplication( + alpha_shape, A_shape, B_shape, alpha_index, A_index, B_index) + map_ranges, alpha_ranges, A_ranges, B_ranges = ranges + alpha_outer_range, alpha_inner_range = alpha_ranges + A_outer_range, A_inner_range = A_ranges + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + + # Set label + if label is None: + label = alpha_label + A_label + + # Create state + state = State(label=label) + + # Create data nodes + alpha_node = state.add_array( + alpha_label, alpha_shape, alpha_type, transient=is_alpha_transient) + A_node = state.add_array( + A_label, A_shape, A_type, transient=is_A_transient) + B_node = state.add_array( + B_label, B_shape, B_type, transient=is_B_transient or create_B) + + # Create map/tasklet + map_entry, map_exit = state.add_map( + name=label + '_map', ndrange=map_ranges) + map_entry.in_connectors = {'IN_1', 'IN_2'} + map_entry.out_connectors = {'OUT_1', 'OUT_2'} + map_exit.in_connectors = {'IN_1'} + map_exit.out_connectors = {'OUT_1'} + tasklet = state.add_tasklet( + name=label + '_tasklet', + inputs={'scalar', 'a'}, + outputs={'b'}, + code='b = scalar * a') + + # Add edges + state.add_edge(alpha_node, None, map_entry, 'IN_1', + dace.Memlet.simple(alpha_node, alpha_outer_range)) + state.add_edge(A_node, None, map_entry, 'IN_2', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(map_exit, 'OUT_1', B_node, None, + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(map_entry, 'OUT_1', tasklet, 'scalar', + dace.Memlet.simple(alpha_node, alpha_inner_range)) + state.add_edge(map_entry, 'OUT_2', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + if accumulate: + state.add_edge( + tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple( + B_node, + B_inner_range, + wcr_str='lambda x, y: x + y', + wcr_identity=None, + wcr_conflict=wcr_conflict)) + else: + state.add_edge(tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple(B_node, B_inner_range)) + + return state + + +def constant_array_multiplication(state: State, + constant, + A_src: Node, + A_node: DNode, + B_dst: Node, + B_node: DNode, + accumulate: bool = False, + A_index: Index = None, + B_index: Index = None, + label: str = None): + """ Adds a scalar-array multiplication operation to an exisiting state. 
""" + + # Validate data + # ranges = validate_scalar_array_multiplication( + # [1], A_node.shape, B_node.shape, + # None, A_index, B_index + # ) + sdfg = state.parent + ranges = validate_scalar_array_multiplication([1], + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, + None, A_index, B_index) + map_ranges, _, A_ranges, B_ranges = ranges + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + + # Set label + if label is None: + label = state.label + + # Create map/tasklet + map_entry, map_exit = state.add_map( + name=label + '_map', ndrange=map_ranges) + map_entry.in_connectors = {'IN_1'} + map_entry.out_connectors = {'OUT_1'} + map_exit.in_connectors = {'IN_1'} + map_exit.out_connectors = {'OUT_1'} + tasklet = state.add_tasklet( + name=label + '_tasklet', + inputs={'a'}, + outputs={'b'}, + code='b = {} * a'.format(constant)) + + # Add edges + state.add_edge(A_src, None, map_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(map_exit, 'OUT_1', B_dst, None, + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(map_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + if accumulate: + state.add_edge( + tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple( + B_node, + B_inner_range, + wcr_str='lambda x, y: x + y', + wcr_identity=None, + wcr_conflict=False)) + else: + state.add_edge(tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple(B_node, B_inner_range)) + + +def unary_array_op(state: State, + A_src: Node, + A_node: DNode, + B_dst: Node, + B_node: DNode, + code: str, + lang=dace.types.Language.Python, + accumulate: bool = False, + A_index: Index = None, + B_index: Index = None, + label: str = None): + """ Adds a unary array operation to an exisiting state. """ + + # Validate data + sdfg = state.parent + ranges = validate_scalar_array_multiplication([1], + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, + None, A_index, B_index) + map_ranges, _, A_ranges, B_ranges = ranges + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + + # Set label + if label is None: + label = state.label + + # Create map/tasklet + map_entry, map_exit = state.add_map( + name=label + '_map', ndrange=map_ranges) + map_entry.in_connectors = {'IN_1'} + map_entry.out_connectors = {'OUT_1'} + map_exit.in_connectors = {'IN_1'} + map_exit.out_connectors = {'OUT_1'} + tasklet = state.add_tasklet( + name=label + '_tasklet', + inputs={'a'}, + outputs={'b'}, + code=code, + language=lang) + + # Add edges + state.add_edge(A_src, None, map_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(map_exit, 'OUT_1', B_dst, None, + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(map_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + if accumulate: + state.add_edge( + tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple( + B_node, + B_inner_range, + wcr_str='lambda x, y: x + y', + wcr_identity=None, + wcr_conflict=False)) + else: + state.add_edge(tasklet, 'b', map_exit, 'IN_1', + dace.Memlet.simple(B_node, B_inner_range)) + + +def validate_matrix_transpose( + A_shape: Shape, + B_shape: Shape, + A_index: Index = None, + B_index: Index = None +) -> (typing.Dict[str, str], (str, str), (str, str)): + """ Validates a matrix transpose operation, based on the shapes and indices + of the arrays involved. Returns the ranges of the maps and memlets at + all levels as strings. 
""" + + # Validate data + if len(A_shape) < 2: + raise ValidationError( + 'Array A has less than 2 dimensions: {}'.format(A_shape)) + A_tr_shape = A_shape[-2:] + if len(B_shape) < 2: + raise ValidationError( + 'Array B has less than 2 dimensions: {}'.format(B_shape)) + B_tr_shape = B_shape[-2:] + if A_tr_shape[0] != B_tr_shape[-1] or A_tr_shape[-1] != B_tr_shape[0]: + raise ValidationError( + 'Dimension mismatch between arrays A and B: {} != {}'.format( + A_tr_shape, B_tr_shape)) + + # Map ranges + map_ranges = dict( + ii='0:{}'.format(A_tr_shape[0]), ij='0:{}'.format(A_tr_shape[-1])) + + # Validate slices and set array access ranges + A_outer_range = '0:{}, 0:{}'.format(A_tr_shape[0], A_tr_shape[-1]) + A_inner_range = 'ii, ij' + if len(A_shape) > 2: + if A_index is None or len(A_index) != len(A_shape) - 2: + raise ValidationError( + 'Invalid slice {} for array A with dimensions {}'.format( + A_index, A_shape)) + A_index = [str(idx) for idx in A_index] + A_outer_range = '{}, {}'.format(', '.join(A_index), A_outer_range) + A_inner_range = '{}, {}'.format(', '.join(A_index), A_inner_range) + B_outer_range = '0:{}, 0:{}'.format(A_tr_shape[-1], A_tr_shape[0]) + B_inner_range = 'ij, ii' + if len(B_shape) > 2: + if B_index is None or len(B_index) != len(B_shape) - 2: + raise ValidationError( + 'Invalid slice {} for array B with dimensions {}'.format( + B_index, B_shape)) + B_index = [str(idx) for idx in B_index] + B_outer_range = '{}, {}'.format(', '.join(B_index), B_outer_range) + B_inner_range = '{}, {}'.format(', '.join(B_index), B_inner_range) + + return (map_ranges, (A_outer_range, A_inner_range), (B_outer_range, + B_inner_range)) + + +def matrix_transpose(state: State, + A_src: Node, + A_node: DNode, + B_dst: Node, + B_node: DNode, + A_index: Index = None, + B_index: Index = None, + code: str = None, + lang=dace.types.Language.Python, + label: str = None): + """ Adds a matrix transpose operation to an existing state. """ + + # Validate data + sdfg = state.parent + map_ranges, A_ranges, B_ranges = validate_matrix_transpose( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, A_index, B_index) + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + + # Set label + if label is None: + label = state.label + + # Create map/tasklet + if code is None: + code = 'b = a' + _, map_entry, map_exit = state.add_mapped_tasklet( + name=label, + map_ranges=map_ranges, + inputs=dict(a=dace.Memlet.simple(A_node, A_inner_range)), + outputs=dict(b=dace.Memlet.simple(B_node, B_inner_range)), + code=code, + language=lang) + + # Add edges + state.add_nedge(A_src, map_entry, dace.Memlet.simple( + A_node, A_outer_range)) + state.add_nedge(map_exit, B_dst, dace.Memlet.simple(B_node, B_outer_range)) + + return state + + +def matrix_transpose_double(state: State, + A_src: Node, + A_node: DNode, + B_dst: Node, + B_node: DNode, + C_dst: Node, + C_node: DNode, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + code: str = None, + lang=dace.types.Language.Python, + label: str = None): + """ Adds a matrix transpose operation, which transposes to two different + matrices, to an existing state. 
""" + + # Validate data + sdfg = state.parent + map_ranges, A_ranges, B_ranges = validate_matrix_transpose( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, A_index, B_index) + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + _, _, C_ranges = validate_matrix_transpose( + A_node.desc(sdfg).shape, + C_node.desc(sdfg).shape, A_index, C_index) + C_outer_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = state.label + + # Create map/tasklet + if code is None: + code = ''' +b = a +c = a + ''' + _, map_entry, map_exit = state.add_mapped_tasklet( + name=label, + map_ranges=map_ranges, + inputs=dict(a=dace.Memlet.simple(A_node, A_inner_range)), + outputs=dict( + b=dace.Memlet.simple(B_node, B_inner_range), + c=dace.Memlet.simple(C_node, C_inner_range), + ), + code=code, + language=lang) + + # Add edges + state.add_nedge(A_src, map_entry, dace.Memlet.simple( + A_node, A_outer_range)) + state.add_nedge(map_exit, B_dst, dace.Memlet.simple(B_node, B_outer_range)) + state.add_nedge(map_exit, C_dst, dace.Memlet.simple(C_node, C_outer_range)) + + return state + + +def matrix_transpose_s(A_label: str, + A_shape: Shape, + A_type: dace.types.typeclass, + create_B: bool = True, + B_label: str = None, + B_shape: Shape = None, + B_type: dace.types.typeclass = None, + is_alpha_transient: bool = False, + is_A_transient: bool = False, + is_B_transient: bool = False, + A_index: Index = None, + B_index: Index = None, + label: str = None) -> State: + """ Creates a new state with a matrix transpose operation. """ + + # Set output attributes + if create_B: + if B_label is None: + B_label = A_label + '^T' + if B_type is None: + B_type = A_type + B_shape = list(A_shape).reverse() + else: + if B_shape is None: + raise ValidationError( + 'Array B is not transient, but its shape is not set') + + # Validate data + map_ranges, A_ranges, B_ranges = validate_matrix_transpose( + A_shape, B_shape, A_index, B_index) + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + + # Set label + if label is None: + label = A_label + '^T' + + # Create state + state = State(label=label) + + # Create datanodes + A_node = state.add_array( + A_label, A_shape, A_type, transient=is_A_transient) + B_node = state.add_array( + B_label, B_shape, B_type, transient=is_B_transient or create_B) + + # Create map/tasklet + _, map_entry, map_exit = state.add_mapped_tasklet( + name=label, + map_ranges=map_ranges, + inputs=dict(a=dace.Memlet.simple(A_node, A_inner_range)), + outputs=dict(b=dace.Memlet.simple(B_node, B_inner_range)), + code='b = a') + + # Add edges + state.add_nedge(A_node, map_entry, dace.Memlet.simple( + A_node, A_outer_range)) + state.add_nedge(map_exit, B_node, dace.Memlet.simple( + B_node, B_outer_range)) + + return state + + +def validate_matrix_pointwise_op( + A_shape: Shape, + B_shape: Shape, + C_shape: Shape, + reduce: bool = False, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None +) -> (typing.Dict[str, str], (str, str), (str, str), (str, str)): + """ Validates a point-wise matrix operation. 
""" + + # Validate data + if A_index is not None: + true_A_shape = A_shape[len(A_index):] + else: + true_A_shape = A_shape + if B_index is not None: + true_B_shape = B_shape[len(B_index):] + else: + true_B_shape = B_shape + if true_A_shape != true_B_shape: + raise ValidationError('Dimension mismatch between arrays A and B: ' + '{}({}) != {}({})'.format( + true_A_shape, A_shape, true_B_shape, + B_shape)) + if reduce: + if C_index is None or len(C_shape) != len(C_index): + raise ValidationError( + 'Point-wise matrix operation result cannot be reduced: ' + '{}({})'.format(C_shape, C_index)) + else: + if C_index is not None: + true_C_shape = C_shape[len(C_index):] + else: + true_C_shape = C_shape + if true_A_shape != true_B_shape: + raise ValidationError('Dimension mismatch between arrays A and C: ' + '{}({}) != {}({})'.format( + true_A_shape, A_shape, true_C_shape, + C_shape)) + + # Map ranges + map_ranges = dict() + for i, dim in enumerate(true_A_shape): + map_ranges['i{}'.format(i)] = '0:{}'.format(dim) + + # Memlet ranges + A_outer_range = ', '.join(map_ranges.values()) + A_inner_range = ', '.join(map_ranges.keys()) + if A_index is not None: + A_index = [str(idx) for idx in A_index] + A_outer_range = '{}, {}'.format(', '.join(A_index), A_outer_range) + A_inner_range = '{}, {}'.format(', '.join(A_index), A_inner_range) + B_outer_range = ', '.join(map_ranges.values()) + B_inner_range = ', '.join(map_ranges.keys()) + if B_index is not None: + B_index = [str(idx) for idx in B_index] + B_outer_range = '{}, {}'.format(', '.join(B_index), B_outer_range) + B_inner_range = '{}, {}'.format(', '.join(B_index), B_inner_range) + if reduce: + C_index = [str(idx) for idx in C_index] + C_outer_range = ', '.join(C_index) + C_inner_range = ', '.join(C_index) + else: + C_outer_range = ', '.join(map_ranges.values()) + C_inner_range = ', '.join(map_ranges.keys()) + if C_index is not None: + C_index = [str(idx) for idx in C_index] + C_outer_range = '{}, {}'.format(', '.join(C_index), C_outer_range) + C_inner_range = '{}, {}'.format(', '.join(C_index), C_inner_range) + + return (map_ranges, (A_outer_range, A_inner_range), + (B_outer_range, B_inner_range), (C_outer_range, C_inner_range)) + + +def matrix_pointwise_op(state: State, + A_src: Node, + A_node: DNode, + B_src: Node, + B_node: DNode, + C_dst: Node, + C_node: DNode, + op: str, + reduce: bool = False, + reduce_op: str = None, + accumulate: bool = False, + A_index: Index = None, + B_index: Index = None, + C_index: Index = None, + label: str = None): + """ Adds a matrix point-wise operation to an existing state. 
""" + + # Validate data + sdfg = state.parent + C_shape = None + if reduce and not hasattr(C_node.desc(sdfg), 'shape'): + C_shape = [1] + else: + C_shape = C_node.desc(sdfg).shape + map_ranges, A_ranges, B_ranges, C_ranges = validate_matrix_pointwise_op( + A_node.desc(sdfg).shape, + B_node.desc(sdfg).shape, C_shape, reduce, A_index, B_index, C_index) + A_outer_range, A_inner_range = A_ranges + B_outer_range, B_inner_range = B_ranges + C_outer_range, C_inner_range = C_ranges + + # Set label + if label is None: + label = state.label + + # Create map/tasklet + if reduce: + schedule = dace.types.ScheduleType.Sequential + else: + schedule = dace.types.ScheduleType.Default + map_entry, map_exit = state.add_map( + name=label + '_map', ndrange=map_ranges, schedule=schedule) + map_entry.in_connectors = {'IN_1', 'IN_2'} + map_entry.out_connectors = {'OUT_1', 'OUT_2'} + map_exit.in_connectors = {'IN_1'} + map_exit.out_connectors = {'OUT_1'} + tasklet = state.add_tasklet( + name=label + '_tasklet', + inputs={'a', 'b'}, + outputs={'c'}, + code='c = a ' + op + ' b') + + # Add edges + state.add_edge(A_src, None, map_entry, 'IN_1', + dace.Memlet.simple(A_node, A_outer_range)) + state.add_edge(B_src, None, map_entry, 'IN_2', + dace.Memlet.simple(B_node, B_outer_range)) + state.add_edge(map_exit, 'OUT_1', C_dst, None, + dace.Memlet.simple(C_node, C_outer_range)) + state.add_edge(map_entry, 'OUT_1', tasklet, 'a', + dace.Memlet.simple(A_node, A_inner_range)) + state.add_edge(map_entry, 'OUT_2', tasklet, 'b', + dace.Memlet.simple(B_node, B_inner_range)) + if reduce: + wcr = 0 + if accumulate: + wcr = None + state.add_edge( + tasklet, 'c', map_exit, 'IN_1', + dace.Memlet.simple( + C_node, + C_inner_range, + wcr_str='lambda x, y: x ' + reduce_op + ' y', + wcr_identity=wcr, + wcr_conflict=False)) + else: + state.add_edge(tasklet, 'c', map_exit, 'IN_1', + dace.Memlet.simple(C_node, C_inner_range)) + + +def csr2dense_cusparse(state: State, val: DNode, rowptr: DNode, colind: DNode, + dense: DNode): + """ Adds a CSR->Dense data layout transformation to a state, using + CUSPARSE for the implementation. """ + sdfg = state.parent + dense_array = dense.desc(sdfg) + d_shape = dense_array.shape + d_dtype = dense_array.dtype + T = state.add_transient(dense.data + 'T', d_shape, d_dtype) + + tasklet = state.add_tasklet( + name=dense.data + '_csr2dense', + inputs={'val', 'rowptr', 'colind'}, + outputs={'dense'}, + code=''' + cusparseSetStream(sparse_handle, __dace_current_stream); + cusparseZcsr2dense( + sparse_handle, + {m}, {n}, + sparse_mat_descr, + (cuDoubleComplex*)val, + rowptr, + colind, + (cuDoubleComplex*)dense, + {m} + ); + '''.format(m=str(d_shape[0]), n=str(d_shape[1])), + language=dace.types.Language.CPP) + state.add_edge(val, None, tasklet, 'val', + dace.Memlet.from_array(val.data, val.desc(sdfg))) + state.add_edge(rowptr, None, tasklet, 'rowptr', + dace.Memlet.from_array(rowptr.data, rowptr.desc(sdfg))) + state.add_edge(colind, None, tasklet, 'colind', + dace.Memlet.from_array(colind.data, colind.desc(sdfg))) + state.add_edge(tasklet, 'dense', T, None, + dace.Memlet.from_array(T.data, T.desc(sdfg))) + gpu_transform_tasklet(sdfg, state, tasklet) + matrix_transpose(state, T, T, dense, dense, label=T.data) + + +def matrix_inversion_cusolver(state, arg, mat_inv, mat_index, label): + """ Adds a matrix inverse operation to a state, using CUSOLVER + for the implementation. 
""" + + sdfg = state.parent + m_shape = mat_inv.desc(sdfg).shape + inv_range = '0 : {sz}, 0 : {sz}'.format(sz=m_shape[-1]) + if mat_index is not None: + index = [str(idx) for idx in mat_index] + inv_range = '{}, {}'.format(', '.join(index), inv_range) + inv_task = state.add_tasklet( + name=label, + inputs={'a'}, + outputs={'b'}, + code=''' + cusolverDnSetStream(solver_handle, __dace_current_stream); + int new_lwork = 0; + cusolverDnZgetrf_bufferSize( + solver_handle, + {n}, {n}, + (cuDoubleComplex*)a, + {n}, + &new_lwork + ); + //cudaDeviceSynchronize(); + if (new_lwork > lwork) {{ + lwork = new_lwork; + cudaFree(dwork); + cudaMalloc(&dwork, sizeof(cuDoubleComplex) * lwork); + }} + cusolverDnZgetrf( + solver_handle, + {n}, {n}, + (cuDoubleComplex*)a, + {n}, + dwork, ipiv, info + ); + //cudaDeviceSynchronize(); + cudaMemcpyAsync(b, dev_I, sizeof(cuDoubleComplex) * {n} * {n}, cudaMemcpyDeviceToDevice, __dace_current_stream); + cusolverDnZgetrs( + solver_handle, + CUBLAS_OP_N, + {n}, + {n}, /* nrhs */ + (cuDoubleComplex*)a, + {n}, + ipiv, + (cuDoubleComplex*)b, + {n}, + info + ); + //cudaDeviceSynchronize(); + '''.format(n=m_shape[-1]), + language=dace.types.Language.CPP) + state.add_edge(arg, None, inv_task, 'a', + dace.Memlet.from_array(arg.data, arg.desc(sdfg))) + state.add_edge(inv_task, 'b', mat_inv, None, + dace.Memlet.simple(mat_inv, inv_range)) + gpu_transform_tasklet(sdfg, state, inv_task) diff --git a/dace/frontend/octave/__init__.py b/dace/frontend/octave/__init__.py new file mode 100644 index 0000000000..362d8c7d52 --- /dev/null +++ b/dace/frontend/octave/__init__.py @@ -0,0 +1 @@ +from .ast_node import AST_Node, AST_Statements \ No newline at end of file diff --git a/dace/frontend/octave/ast_arrayaccess.py b/dace/frontend/octave/ast_arrayaccess.py new file mode 100644 index 0000000000..c78a98920c --- /dev/null +++ b/dace/frontend/octave/ast_arrayaccess.py @@ -0,0 +1,217 @@ +import dace + +from .ast_node import AST_Node + + +class AST_ArrayAccess(AST_Node): + def __init__(self, context, arrayname, accdims): + AST_Node.__init__(self, context) + self.arrayname = arrayname + self.accdims = accdims + + def __repr__(self): + return "AST_ArrayAccess(" + str(self.arrayname) + ", " + str( + self.accdims) + ")" + + def get_children(self): + ret = [self.arrayname] + ret += self.accdims + return ret + + def replace_child(self, old, new): + if old == self.arrayname: + self.arrayname = new + return + if old in self.accdims: + newaccdims = [new if x == old else x for x in self.accdims] + self.accdims = newaccdims + + def get_basetype(self): + # The basetype of an array access is the same as the basetype as the + # array that is acccessed. 
+ vardef = self.search_vardef_in_scope(self.arrayname.get_name()) + return (vardef.get_basetype()) + + def get_dims(self): + from .ast_matrix import AST_Matrix + from .ast_loop import AST_ForLoop + from .ast_values import AST_Constant, AST_Ident + from .ast_range import AST_RangeExpression + # array indexing has many forms/cases in matlab and does not seem to + # be fully documented, the idea is to implement the simple things + # we are sure about and bail out on anything that looks different + dims = [] + if isinstance(self.accdims, list): + for acc in self.accdims: + if isinstance(acc, AST_Constant): + dims.append(1) + elif isinstance(acc, AST_Matrix): + dims.append(len(acc.get_values_row_major())) + elif isinstance(acc, AST_RangeExpression): + if isinstance(acc.lhs, AST_Constant) and isinstance( + acc.rhs, AST_Constant): + l = acc.lhs.get_value() + r = acc.rhs.get_value() + dims.append(r - l + 1) + elif (acc.lhs is None) and (acc.rhs is None): + # Get the dims of the array itself + vardef = self.search_vardef_in_scope( + self.arrayname.get_name()) + if vardef is None: + raise ValueError("No definition found for Array " + + self.arrayname.get_name()) + d = vardef.get_dims() + dims.append(d[len(dims)]) + else: + raise NotImplementedError( + "range with non-constant bounds not supported") + elif isinstance(acc, AST_Ident): + vardef = self.search_vardef_in_scope(acc.get_name()) + if vardef is None: + raise ValueError( + "No definition found for " + acc.get_name() + + " which is used in Array Access: " + str(self)) + if isinstance(vardef, AST_ForLoop) and acc.get_name( + ) == vardef.var.get_name(): + d = vardef.initializer.get_dims()[:-1] + if d != [1]: + raise NotImplementedError( + "Complicated slicing not implemented yet.") + else: + dims.append(d[0]) + else: + raise NotImplementedError( + "unimplemented method of array access (" + str(acc) + + ")") + else: + raise NotImplementedError("unimplemented method of array access") + + # simplify [1,1] to [1] + if dims == [1, 1]: + dims = [1] + return dims + + def make_range_from_accdims(self): + from .ast_range import AST_RangeExpression + from .ast_values import AST_Constant + + rangelist = [] + for acc in self.accdims: + if isinstance(acc, AST_Constant): + rangelist.append((acc.get_value() - 1, acc.get_value() - 1, 1)) + elif isinstance(acc, AST_RangeExpression): + if isinstance(acc.lhs, AST_Constant) and isinstance( + acc.rhs, AST_Constant): + l = acc.lhs.get_value() + r = acc.rhs.get_value() + rangelist.append((l, r, 1)) + else: + raise NotImplementedError( + "range with non-constant bounds not supported: " + + str(self)) + else: + raise NotImplementedError( + "Non-constant array indexing not implemented: " + + str(self)) + ret = dace.subsets.Range(rangelist) + return ret + + def is_data_dependent_access(self): + from .ast_values import AST_Constant + res = False + for a in self.accdims: + if not isinstance(a, AST_Constant): + return True + + def generate_code(self, sdfg, state): + from .ast_values import AST_Ident + from .ast_loop import AST_ForLoop + from .ast_range import AST_RangeExpression + # add a new variable to hold the result of this expression + dims = self.get_dims() + basetype = self.get_basetype() + name = self.get_name_in_sdfg(sdfg) + if name not in sdfg.arrays: + sdfg.add_transient(name, dims, basetype, debuginfo=self.context) + # add a memlet from the original array to the transient + resnode = self.get_datanode(sdfg, state) + arrnode = self.arrayname.get_datanode(sdfg, state) + arrdesc = arrnode.desc(sdfg) + + if 
not self.is_data_dependent_access():
+            msubset = self.make_range_from_accdims()
+            memlet = dace.memlet.Memlet(
+                arrnode,
+                msubset.num_elements(),
+                msubset,
+                1,
+                None,
+                None,
+                debuginfo=self.context)
+            sdfg.nodes()[state].add_edge(arrnode, None, resnode, None, memlet)
+        else:
+            # add a map around the access and feed the access dims that are
+            # runtime dependent into a connector which is _not_ named IN
+            access_data_nodes = set()
+            access_dims = []
+            for idx, acc in enumerate(self.accdims):
+                if isinstance(acc, AST_Ident):
+                    vardef = self.search_vardef_in_scope(acc.get_name())
+                    if vardef is None:
+                        raise ValueError('No definition found for ' +
+                                         str(acc.get_name()))
+                    elif isinstance(vardef, AST_ForLoop):
+                        access_data_nodes.add(vardef.var)
+                        access_dims.append(vardef.var.get_name())
+                elif isinstance(acc, AST_RangeExpression):
+                    # if the bounds are identifiers, we need them on the map
+                    # otherwise we do not need to do anything here
+                    if isinstance(acc.lhs, AST_Ident):
+                        access_data_nodes.add(acc.lhs)
+                    if isinstance(acc.rhs, AST_Ident):
+                        access_data_nodes.add(acc.rhs)
+                    if (acc.lhs is None) and (acc.rhs is None):
+                        d = arrdesc.shape
+                        access_dims.append('0:' + str(d[idx]))
+                else:
+                    acc.generate_code(sdfg, state)
+                    access_data_nodes.add(acc)
+                    access_dims.append(acc.get_name_in_sdfg(sdfg))
+            # now construct the dictionary for the map range
+            s = sdfg.nodes()[state]
+            mdict = {}
+            for aa in access_data_nodes:
+                a = aa.get_name_in_sdfg(sdfg)
+                mdict[a] = a
+            if len(mdict) == 0:
+                mdict = {'__DAPUNUSED_i': '0:1'}
+            men, mex = s.add_map('datadepacc', mdict)
+            men._in_connectors.add('IN_1')
+            men._out_connectors.add('OUT_1')
+            s.add_edge(arrnode, None, men, 'IN_1',
+                       dace.memlet.Memlet.from_array(arrnode.data, arrdesc))
+            for a in access_data_nodes:
+                aname = a.get_name_in_sdfg(sdfg)
+                men._in_connectors.add(aname)
+                datanode = a.get_datanode(sdfg, state)
+                s.add_edge(
+                    datanode, None, men, aname,
+                    dace.memlet.Memlet.from_array(datanode.data,
+                                                  datanode.desc(sdfg)))
+            tasklet = s.add_tasklet('ident', {'in'}, {'out'}, 'out = in;',
+                                    dace.Language.CPP)
+            s.add_edge(
+                men, 'OUT_1', tasklet, 'in',
+                dace.memlet.Memlet.simple(arrnode, ','.join(access_dims)))
+            s.add_edge(
+                tasklet, 'out', mex, None,
+                dace.memlet.Memlet.from_array(resnode.data,
+                                              resnode.desc(sdfg)))
+            s.add_edge(
+                mex, None, resnode, None,
+                dace.memlet.Memlet.from_array(resnode.data,
+                                              resnode.desc(sdfg)))
+
+        print("The result of " + str(self) + " will be stored in " + str(name))
+
+    __str__ = __repr__ diff --git a/dace/frontend/octave/ast_assign.py b/dace/frontend/octave/ast_assign.py new file mode 100644 index 0000000000..25a52db783 --- /dev/null +++ b/dace/frontend/octave/ast_assign.py @@ -0,0 +1,174 @@
+from .ast_node import AST_Node
+from .ast_values import AST_Ident
+
+import dace
+
+
+class AST_Assign(AST_Node):
+    def __init__(self, context, lhs, rhs, op):
+        # for a normal assignment op is "=", but there is also
+        # in place modification, i.e., "+="
+        AST_Node.__init__(self, context)
+        self.lhs = lhs
+        self.rhs = rhs
+        self.op = op
+        self.children = [self.lhs, self.rhs]
+
+    def get_children(self):
+        retval = [self.lhs, self.rhs]
+        return retval
+
+    def replace_child(self, old, new):
+        if old == self.lhs:
+            self.lhs = new
+        if old == self.rhs:
+            self.rhs = new
+
+    def defined_variables(self):
+        # check if this adds something to the scope, if yes add it.
+        # assume A is undefined before this node, then:
+        # A = expr defines A, A(5) = expr defines A, but
+        # A += expr or A(5) += expr is illegal.
+        if self.op == "=":
+            if isinstance(self.lhs, AST_Ident):
+                return [self.lhs.get_name()]
+        else:
+            return []
+
+    def provide_parents(self, parent):
+        self.parent = parent
+        self.lhs.provide_parents(self)
+        self.rhs.provide_parents(self)
+
+    def __repr__(self):
+        return "AST_Assign(" + str(self.lhs) + ", " + str(
+            self.op) + ", " + str(self.rhs) + ")"
+
+    def print_nodes(self, state):
+        for n in state.nodes():
+            print(str(n))
+        print("---")
+
+    def generate_code(self, sdfg, state):
+        from .ast_arrayaccess import AST_ArrayAccess
+        from .ast_values import AST_Constant
+        from .ast_loop import AST_ForLoop
+
+        self.rhs.generate_code(sdfg, state)
+        s = sdfg.nodes()[state]
+        if self.op == "=":
+            # We assign to an entire array
+            if isinstance(self.lhs, AST_Ident):
+                dims = self.rhs.get_dims()
+                basetype = self.rhs.get_basetype()
+                name = self.lhs.get_name()
+
+                if name not in sdfg.arrays:
+                    sdfg.add_array(
+                        name, dims, basetype, debuginfo=self.context)
+                rhs_datanode = self.rhs.get_datanode(sdfg, state)
+                lhs_datanode = self.lhs.get_datanode(sdfg, state)
+
+                s.add_edge(
+                    rhs_datanode, None, lhs_datanode, None,
+                    dace.memlet.Memlet.from_array(lhs_datanode.data,
+                                                  lhs_datanode.desc(sdfg)))
+
+            # We assign only to a part of an (existing) array; in order to not
+            # create cycles we need to add a new data-node, and the add_array()
+            # interface will make sure it is connected to the same memory as
+            # the existing array node.
+            elif isinstance(self.lhs, AST_ArrayAccess):
+                # get the definition of the array we are assigning to
+                lhs_data = self.lhs.arrayname.get_datanode(sdfg, state)
+                vardef = self.search_vardef_in_scope(
+                    self.lhs.arrayname.get_name())
+                if vardef is None:
+                    raise ValueError("No definition found for " +
+                                     self.lhs.arrayname.get_name() +
+                                     " searching from " + str(self))
+                dims = vardef.get_dims()
+                basetype = vardef.get_basetype()
+                if self.lhs.arrayname.get_name() not in sdfg.arrays:
+                    sdfg.add_array(
+                        self.lhs.arrayname.get_name(),
+                        dims,
+                        basetype,
+                        debuginfo=self.context)
+                dn = sdfg.nodes()[state].add_access(
+                    self.lhs.arrayname.get_name())
+
+                # check if the write is "out of bounds": this _is_ allowed in
+                # matlab, but not in SDFGs, since it would require to
+                # dynamically reallocate the array
+
+                # create a memlet which connects the rhs of the assignment to dn
+                rhs_datanode = self.rhs.get_datanode(sdfg, state)
+
+                if not self.lhs.is_data_dependent_access():
+                    msubset = self.lhs.make_range_from_accdims()
+                    writem = dace.memlet.Memlet(
+                        self.lhs.arrayname.get_name(),
+                        msubset.num_elements(),
+                        msubset,
+                        1,
+                        None,
+                        None,
+                        debuginfo=self.context)
+
+                    sdfg.nodes()[state].add_edge(rhs_datanode, None, dn, None,
+                                                 writem)
+                else:
+                    s = sdfg.nodes()[state]
+                    acc_data_nodes = set()
+                    acc_dims = []
+                    for a in self.lhs.accdims:
+                        if isinstance(a, AST_Constant):
+                            acc_dims.append(a.get_value())
+                        elif isinstance(a, AST_Ident):
+                            vardef = self.search_vardef_in_scope(a.get_name())
+                            if vardef is None:
+                                raise ValueError('No definition found for ' +
+                                                 str(a.get_name()))
+                            elif isinstance(vardef, AST_ForLoop):
+                                acc_data_nodes.add(vardef.var)
+                                acc_dims.append(vardef.var.get_name())
+                        else:
+                            raise ValueError(
+                                str(type(a)) +
+                                " in data dependent write not allowed.")
+                    mapdict = {}
+                    for a in acc_dims:
+                        mapdict[a] = str(a)
+                    men, mex = s.add_map('datadepwrite', mapdict)
+                    men._in_connectors.add(
+                        'IN_1')  # the data to write goes here
+                    men._out_connectors.add('OUT_1')  # and comes out here
+                    for d in acc_data_nodes:
+                        dname = d.get_name_in_sdfg(sdfg)
+
men._in_connectors.add(dname) + datanode = d.get_datanode(sdfg, state) + s.add_edge( + datanode, None, men, dname, + dace.memlet.Memlet.from_array( + datanode.data, datanode.desc(sdfg))) + s.add_edge( + rhs_datanode, None, men, 'IN_1', + dace.memlet.Memlet.from_array(rhs_datanode.data, + rhs_datanode.desc(sdfg))) + s.add_edge( + men, 'OUT_1', dn, None, + dace.memlet.Memlet.simple( + self.lhs.arrayname.get_name(), + ','.join([str(d) for d in acc_dims]))) + s.add_edge(dn, None, mex, None, dace.memlet.EmptyMemlet()) + + else: + raise NotImplementedError("Assignment with lhs of type " + + str(type(self.lhs)) + + " has not been implemented yet.") + else: + raise NotImplementedError("Assignment operator " + self.op + + " has not been implemented yet.") + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_expression.py b/dace/frontend/octave/ast_expression.py new file mode 100644 index 0000000000..c35b5f0263 --- /dev/null +++ b/dace/frontend/octave/ast_expression.py @@ -0,0 +1,304 @@ +import dace + +from .ast_node import AST_Node + + +class AST_UnaryExpression(AST_Node): + def __init__(self, context, arg, op, order): + AST_Node.__init__(self, context) + self.arg = arg + self.op = op + self.order = order # can be "pre" or "post" (++A vs A++) + self.children = [self.arg] + + def __repr__(self): + return "AST_UnaryExpression(" + str(self.arg) + ", " + str(self.op) + \ + ", " + str(self.order) + ")" + + def get_children(self): + return [self.arg] + + def replace_child(self, old, new): + if self.arg == old: + self.arg = new + else: + raise ValueError(str(old) + " is not a child of " + str(self)) + + def specialize(self): + from .ast_values import AST_Constant + # -A is syntactic sugar for -1*A + if (self.op == "-") and isinstance(self.arg, AST_Constant): + new = AST_Constant(self.context, -self.arg.get_value()) + new.next = self.next + new.prev = self.prev + new.parent = self.parent + return new + elif (self.op == "-"): + new = AST_BinExpression(self.context, self.arg, + AST_Constant(None, -1), "*") + new.next = self.next + new.prev = self.prev + new.parent = self.parent + return new + + __str__ = __repr__ + + +class AST_BinExpression(AST_Node): + def __init__(self, context, lhs, rhs, op): + AST_Node.__init__(self, context) + self.lhs = lhs + self.rhs = rhs + self.op = op + self.children = [self.lhs, self.rhs] + + def provide_parents(self, parent): + self.parent = parent + self.lhs.provide_parents(self) + self.rhs.provide_parents(self) + + def get_children(self): + return [self.lhs, self.rhs] + + def replace_child(self, old, new): + if self.lhs == old: + self.lhs = new + if self.rhs == old: + self.rhs = new + + def __repr__(self): + return "AST_BinExpression(" + str(self.lhs) + ", " + str( + self.op) + ", " + str(self.rhs) + ")" + + def get_dims(self): + left_dims = self.lhs.get_dims() + right_dims = self.rhs.get_dims() + if len(left_dims) > 2 or len(right_dims) > 2: + raise ValueError("Only 2D matrices can be multiplied") + outdims = None + if self.op == "*": + # if lhs is a scalar, outdims = rhs + if left_dims == [1]: + outdims = right_dims + # elif rhs is a scalar, outdims = lhs + elif right_dims == [1]: + outdims = left_dims + # elif lhs is a matrix, check if dims match, compute new outdims + elif left_dims[1] != right_dims[0]: + print(str(left_dims) + "type: " + str(type(left_dims[1]))) + print(str(right_dims) + "type: " + str(type(right_dims[0]))) + raise ValueError("Dims do not match!") + else: + outdims = [left_dims[0], right_dims[1]] + elif self.op == "+" or self.op == "-" or self.op 
== "/": + # if lhs is a scalar, outdims = rhs + if left_dims == [1]: + outdims = right_dims + # elif rhs is a scalar, outdims = lhs + elif right_dims == [1]: + outdims = left_dims + # elif lhs is a matrix, check if dims match, compute new outdims + elif left_dims != right_dims: + raise ValueError("Dimensions do not match") + else: + outdims = left_dims + else: + raise NotImplementedError("Unhandled binary operator: " + + str(self.op)) + if outdims == [1, 1]: + outdims = [1] + return outdims + + def get_basetype(self): + # The basetype of a binary expression should be the more accurate + # type of lhs and rhs + return dace.types.float64 + + def matrix2d_scalar(self, sdfg, state, op): + lhs_dims = self.lhs.get_dims() + rhs_dims = self.rhs.get_dims() + M = str(lhs_dims[-2]) + N = str(lhs_dims[-1]) + A = self.lhs.get_datanode(sdfg, state) + B = self.rhs.get_datanode(sdfg, state) + C = self.get_datanode(sdfg, state) + + s = sdfg.nodes()[state] + map_entry, map_exit = s.add_map('M' + op + 'M', + dict(i='0:' + M, j='0:' + N)) + map_entry._in_connectors.add('IN_1') + map_entry._in_connectors.add('IN_2') + map_entry._out_connectors.add('OUT_1') + map_entry._out_connectors.add('OUT_2') + s.add_edge(A, None, map_entry, 'IN_1', + dace.memlet.Memlet.simple(A, '0:' + N + ',0:' + M)) + s.add_edge(B, None, map_entry, 'IN_2', dace.memlet.Memlet.simple( + B, '0')) + tasklet = s.add_tasklet(op, {'a', 'b'}, {'c'}, 'c = a' + op + 'b') + s.add_edge(map_entry, "OUT_1", tasklet, "a", + dace.memlet.Memlet.simple(A, 'i,j')) + s.add_edge(map_entry, "OUT_2", tasklet, "b", + dace.memlet.Memlet.simple(B, '0')) + s.add_edge(tasklet, "c", map_exit, None, + dace.memlet.Memlet.simple(C, 'i,j')) + s.add_edge(map_exit, None, C, None, + dace.memlet.Memlet.simple(C, '0:' + N + ', 0:' + M)) + + def matrix2d_matrix2d_mult(self, sdfg, state): + lhs_dims = self.lhs.get_dims() + rhs_dims = self.rhs.get_dims() + A = self.lhs.get_datanode(sdfg, state) + B = self.rhs.get_datanode(sdfg, state) + C = self.get_datanode(sdfg, state) + + M = str(lhs_dims[-1]) + N = str(lhs_dims[-1]) + K = str(rhs_dims[-1]) + + s = sdfg.nodes()[state] + map_entry, map_exit = s.add_map( + 'MMM', dict(i='0:' + M, j='0:' + N, k='0:' + K)) + map_entry._in_connectors.add('IN_1') + map_entry._in_connectors.add('IN_2') + map_entry._out_connectors.add('OUT_1') + map_entry._out_connectors.add('OUT_2') + s.add_edge(A, None, map_entry, 'IN_1', + dace.memlet.Memlet.simple(A, '0:' + M + ',0:' + K)) + s.add_edge(B, None, map_entry, 'IN_2', + dace.memlet.Memlet.simple(B, '0:' + K + ', 0:' + N)) + tasklet = s.add_tasklet('mult', {'a', 'b'}, {'c'}, 'c = a*b') + s.add_edge(map_entry, "OUT_1", tasklet, "a", + dace.memlet.Memlet.simple(A, 'i,k')) + s.add_edge(map_entry, "OUT_2", tasklet, "b", + dace.memlet.Memlet.simple(B, 'k,j')) + tmpname = self.get_new_tmpvar(sdfg) + sdfg.add_transient(tmpname, [M, N, K], self.get_basetype()) + tmp = s.add_access(tmpname) + s.add_edge(tasklet, "c", map_exit, None, + dace.memlet.Memlet.simple(tmp, 'i,j,k')) + rednode = s.add_reduce('lambda a,b: a+b', (2, ), 0) + s.add_edge( + map_exit, None, tmp, None, + dace.memlet.Memlet.simple(tmp, '0:' + M + ',0:' + N + ',0:' + K)) + s.add_edge( + tmp, None, rednode, None, + dace.memlet.Memlet.simple(tmp, '0:' + M + ',0:' + N + ',0:' + K)) + s.add_edge(rednode, None, C, None, + dace.memlet.Memlet.simple(C, '0:' + M + ',0:' + N)) + + def vec_mult_vect(self, sdfg, state, op): + lhs_dims = self.lhs.get_dims() + rhs_dims = self.rhs.get_dims() + A = self.lhs.get_datanode(sdfg, state) + B = 
self.rhs.get_datanode(sdfg, state) + C = self.get_datanode(sdfg, state) + + N = str(lhs_dims[-1]) + + s = sdfg.nodes()[state] + map_entry, map_exit = s.add_map('VVM', dict(i='0:' + N)) + map_entry._in_connectors.add('IN_1') + map_entry._in_connectors.add('IN_2') + map_entry._out_connectors.add('OUT_1') + map_entry._out_connectors.add('OUT_2') + s.add_edge(A, None, map_entry, 'IN_1', + dace.memlet.Memlet.simple(A, '0:' + N)) + s.add_edge(B, None, map_entry, 'IN_2', + dace.memlet.Memlet.simple(B, '0:' + N)) + tasklet = s.add_tasklet('mult', {'a', 'b'}, {'c'}, 'c = a*b') + s.add_edge(map_entry, "OUT_1", tasklet, "a", + dace.memlet.Memlet.simple(A, '0,i')) + s.add_edge(map_entry, "OUT_2", tasklet, "b", + dace.memlet.Memlet.simple(B, 'i,0')) + tmpname = self.get_new_tmpvar(sdfg) + sdfg.add_transient(tmpname, [N], self.get_basetype()) + tmp = s.add_access(tmpname) + s.add_edge(tasklet, "c", map_exit, None, + dace.memlet.Memlet.simple(tmp, 'i')) + rednode = s.add_reduce('lambda a,b: a+b', (0, ), 0) + s.add_edge(map_exit, None, tmp, None, + dace.memlet.Memlet.simple(tmp, '0:' + N)) + s.add_edge(tmp, None, rednode, None, + dace.memlet.Memlet.simple(tmp, '0:' + N)) + s.add_edge(rednode, None, C, None, dace.memlet.Memlet.simple(C, '0')) + + def matrix2d_matrix2d_plus_or_minus(self, sdfg, state, op): + lhs_dims = self.lhs.get_dims() + rhs_dims = self.rhs.get_dims() + M = str(lhs_dims[-2]) + N = str(lhs_dims[-1]) + A = self.lhs.get_datanode(sdfg, state) + B = self.rhs.get_datanode(sdfg, state) + C = self.get_datanode(sdfg, state) + + s = sdfg.nodes()[state] + map_entry, map_exit = s.add_map('M' + op + 'M', + dict(i='0:' + M, j='0:' + N)) + map_entry._in_connectors.add('IN_1') + map_entry._in_connectors.add('IN_2') + map_entry._out_connectors.add('OUT_1') + map_entry._out_connectors.add('OUT_2') + s.add_edge(A, None, map_entry, 'IN_1', + dace.memlet.Memlet.simple(A, '0:' + N + ',0:' + M)) + s.add_edge(B, None, map_entry, 'IN_2', + dace.memlet.Memlet.simple(B, '0:' + N + ', 0:' + M)) + tasklet = s.add_tasklet(op, {'a', 'b'}, {'c'}, 'c = a' + op + 'b') + s.add_edge(map_entry, "OUT_1", tasklet, "a", + dace.memlet.Memlet.simple(A, 'i,j')) + s.add_edge(map_entry, "OUT_2", tasklet, "b", + dace.memlet.Memlet.simple(B, 'i,j')) + s.add_edge(tasklet, "c", map_exit, None, + dace.memlet.Memlet.simple(C, 'i,j')) + s.add_edge(map_exit, None, C, None, + dace.memlet.Memlet.simple(C, '0:' + N + ', 0:' + M)) + + def scalar_scalar(self, sdfg, state, op): + A = self.lhs.get_datanode(sdfg, state) + B = self.rhs.get_datanode(sdfg, state) + C = self.get_datanode(sdfg, state) + + s = sdfg.nodes()[state] + tasklet = s.add_tasklet(op, {'a', 'b'}, {'c'}, 'c = a' + op + 'b') + s.add_edge(A, None, tasklet, 'a', dace.memlet.Memlet.simple(A, '0')) + s.add_edge(B, None, tasklet, 'b', dace.memlet.Memlet.simple(B, '0')) + s.add_edge(tasklet, "c", C, None, dace.memlet.Memlet.simple(C, '0')) + + def generate_code(self, sdfg, state): + # Generate code for the lhs and rhs + self.lhs.generate_code(sdfg, state) + self.rhs.generate_code(sdfg, state) + + # Add a new variable to hold the result of this expression + dims = self.get_dims() + basetype = self.get_basetype() + name = self.get_name_in_sdfg(sdfg) + sdfg.add_transient(name, dims, basetype, debuginfo=self.context) + print("The result of " + str(self) + " will be stored in " + str(name)) + + lhs_dims = self.lhs.get_dims() + rhs_dims = self.rhs.get_dims() + + if rhs_dims == [1, 1] or rhs_dims == [1]: + if lhs_dims == [1, 1] or lhs_dims == [1]: + self.scalar_scalar(sdfg, state, self.op) + 
else: + self.matrix2d_scalar(sdfg, state, self.op) + return + if lhs_dims[0] == 1 and rhs_dims[1] == 1 and self.op == "*": + self.vec_mult_vect(sdfg, state, self.op) + elif lhs_dims == [1, 1] or lhs_dims == [1]: + raise NotImplementedError( + "Binary expression with scalar on lhs not implemented: " + + str(self) + ", lhs dims: " + str(lhs_dims) + ", rhs dims: " + + str(rhs_dims)) + else: + if self.op == "*": + self.matrix2d_matrix2d_mult(sdfg, state) + elif self.op == "-" or self.op == "+": + self.matrix2d_matrix2d_plus_or_minus(sdfg, state, self.op) + else: + raise NotImplementedError("Binary expression with two " + + "matrices and op=" + str(self.op) + + " not implemented") + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_function.py b/dace/frontend/octave/ast_function.py new file mode 100644 index 0000000000..aaa7f5055c --- /dev/null +++ b/dace/frontend/octave/ast_function.py @@ -0,0 +1,371 @@ +import dace +import copy + +from .ast_node import AST_Node + + +class AST_EndFunc(AST_Node): + def __init__(self, context): + AST_Node.__init__(self, context) + + def get_children(self): + return [] + + def replace_child(self, old, new): + raise ValueError("AST_EndFunc has no children") + + def generate_code(self, sdfg, state): + pass + + def __repr__(self): + return "AST_EndFunc()" + + +class AST_Function(AST_Node): + def __init__(self, context, name, args, retvals): + AST_Node.__init__(self, context) + self.name = name + self.args = args + self.retvals = retvals + self.statements = None + + def __repr__(self): + return "AST_Function(" + self.name.get_name() + ", args=[" + ", ".join( + [str(x) for x in self.args]) + "], retvals=[" + ", ".join( + [str(x) for x in self.retvals]) + "])" + + def set_statements(self, stmtlist): + self.statements = AST_Statements(None, stmtlist) + self.statements.provide_parents(self) + + def get_children(self): + ret = [] + ret.append(self.name) + ret += self.args + ret += self.retvals + return ret + + def replace_child(self, old, new): + if self.name == old: + self.name = new + elif old in self.args: + newargs = [new if x == old else x for x in self.args] + self.args = newargs + elif old in self.retvals: + newret = [new if x == old else x for x in self.retvals] + self.retvals = newret + + def generate_code(self, sdfg, state): + # This does not do anything, since we inline functions at the call site, + # so the code generation happens there. 
+ pass + + __str__ = __repr__ + + +class AST_Argument(AST_Node): + def __init__(self, context, name, default=None): + AST_Node.__init__(self, context) + self.name = name + self.default = default + + def get_children(self): + ret = [self.name] + if self.default is not None: + ret += [self.default] + return ret + + def __repr__(self): + return "AST_Argument(" + self.name.get_name() + ", default=" + str( + self.default) + ")" + + __str__ = __repr__ + + +class AST_BuiltInFunCall(AST_Node): + def __init__(self, context, funname, args): + AST_Node.__init__(self, context) + self.funname = funname + self.args = args + + def __repr__(self): + return "AST_BuiltInFunCall(" + str(self.funname) + ", " + str( + self.args) + ")" + + def get_children(self): + retval = self.args[:] + retval.append(self.funname) + return retval + + def replace_child(self, old, new): + if old == self.funname: + self.funname = new + return + if old in self.args: + newargs = [new if x == old else x for x in self.args] + self.args = newargs + + def get_basetype(self): + # For now assume it is always double + return dace.types.float64 + + def get_dims(self): + from .ast_matrix import AST_Matrix + dims = None + if self.funname.get_name() in ["zeros", "ones", "rand", "eye"]: + # The dimensions for these functions are the arguments, but we + # need to convert them to values, if we cannot they are symbolic + for arg in self.args: + if not arg.is_constant(): + + return self.args + if isinstance(self.args[0], AST_Matrix): + dims = self.args[0].get_values_row_major() + else: + dims = [self.args[0].get_value(), self.args[1].get_value()] + elif self.funname.get_name() in ["sqrt"]: + return self.args[0].get_dims() + elif self.funname.get_name() in ["length"]: + dims = [1] + if dims is None: + raise NotImplementedError("Cannot infer dimensions for " + + str(self)) + return dims + + def generate_code(self, sdfg, state): + + # TODO: rand has options for setting seed/state and controlling + # accuracy. We only deal with the simple use-case for now. 
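+        # (For "rand", the tasklet generated below calls drand48(), so the
+        # sequence depends on the process-wide seed; seeding is left to the
+        # embedding program.)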
+ + if self.funname.get_name() in ["sqrt"]: + dims = self.get_dims() + name = self.get_name_in_sdfg(sdfg) + basetype = dace.types.float64 + sdfg.add_transient(name, dims, basetype, debuginfo=self.context) + print("The result of expr " + str(self) + " will be stored in " + + str(name)) + + self.args[0].generate_code(sdfg, state) + + resnode = self.get_datanode(sdfg, state) + if len(dims) == 1: + s = sdfg.nodes()[state] + A = self.args[0].get_datanode(sdfg, state) + tasklet = sdfg.nodes()[state].add_tasklet( + 'sqrt', {'in'}, {'out'}, "out=sqrt(in);", + dace.Language.CPP) + s.add_edge(A, None, tasklet, "in", + dace.memlet.Memlet.from_array(A.data, A.desc(sdfg))) + s.add_edge( + tasklet, "out", resnode, None, + dace.memlet.Memlet.from_array(resnode.data, + resnode.desc(sdfg))) + elif len(dims) == 2: + M = str(dims[0]) + N = str(dims[1]) + + men, mex = sdfg.nodes()[state].add_map( + self.funname.get_name() + 'map', + dict(i="0:" + N, j="0:" + M)) + tasklet = None + s = sdfg.nodes()[state] + A = self.args[0].get_datanode(sdfg, state) + s.add_edge(A, None, men, None, + dace.memlet.Memlet.from_array(A.data, A.desc(sdfg))) + tasklet = sdfg.nodes()[state].add_tasklet( + 'sqrt', {'in'}, {'out'}, "out=sqrt(in);", + dace.Language.CPP) + s.add_edge(men, None, tasklet, "in", + dace.memlet.Memlet.simple(A, 'i,j')) + s.add_edge(tasklet, "out", mex, None, + dace.memlet.Memlet.simple(resnode, 'i,j')) + s.add_edge( + mex, None, resnode, None, + dace.memlet.Memlet.simple(resnode, '0:' + N + ',0:' + M)) + else: + raise ValueError( + "sqrt of tensors with more than 2 dims not supported") + + if self.funname.get_name() in ["zeros", "rand"]: + dims = self.get_dims() + name = self.get_name_in_sdfg(sdfg) + basetype = dace.types.float64 + sdfg.add_transient(name, dims, basetype, debuginfo=self.context) + print("The result of expr " + str(self) + " will be stored in " + + str(name)) + + # Add a map over all dimensions with a tasklet that will initialize + # the array to random values (0,1). 
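+            # (The same map skeleton serves "zeros", whose tasklet writes a
+            # constant 0 instead of a random value.)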
+ + if len(dims) > 2: + raise NotImplementedError( + "Code generation only implemented for 2 arguments") + + resnode = self.get_datanode(sdfg, state) + M = str(dims[0]) + N = str(dims[1]) + + s = sdfg.nodes()[state] + men, mex = s.add_map(self.funname.get_name() + 'map', + dict(i="0:" + N, j="0:" + M)) + tasklet = None + if self.funname.get_name() == "zeros": + tasklet = sdfg.nodes()[state].add_tasklet( + 'zero', {}, {'out'}, "out=0") + s.add_edge(men, None, tasklet, None, dace.memlet.EmptyMemlet()) + elif self.funname.get_name() == "rand": + tasklet = sdfg.nodes()[state].add_tasklet( + 'rand', {}, {'out'}, "out=drand48()") + s.add_edge(men, None, tasklet, None, dace.memlet.EmptyMemlet()) + elif self.funname.get_name() == "sqrt": + A = self.args[0].get_datanode(sdfg, state) + tasklet = sdfg.nodes()[state].add_tasklet( + 'sqrt', {'in'}, {'out'}, "out=sqrt(in)") + s.add_edge(men, None, tasklet, "in", + dace.memlet.Memlet.simple(A, 'i,j')) + else: + raise NotImplementedError("Code generation for " + + str(self.funname.get_name()) + + " is not implemented.") + s = sdfg.nodes()[state] + s.add_edge(tasklet, "out", mex, None, + dace.memlet.Memlet.simple(resnode, 'i,j')) + s.add_edge( + mex, None, resnode, None, + dace.memlet.Memlet.simple(resnode, '0:' + N + ',0:' + M)) + + def specialize(self): + from .ast_matrix import AST_Matrix, AST_Matrix_Row + from .ast_values import AST_Constant, AST_Ident + + # First try to specialize the arguments (for constant propagation) + for c in self.get_children(): + n = c.specialize() + while n is not None: + n.replace_parent(c.get_parent()) + self.replace_child(old=c, new=n) + c = n + n = n.specialize() + for c in self.get_children(): + if isinstance(c, AST_Ident): + if isinstance(c.get_propagated_value(), AST_Constant): + n = copy.deepcopy(c.get_propagated_value()) + self.replace_child(old=c, new=n) + + # If this is a call to zeros, ones, or eye, and the arguments are + # constants, we can generate a constant expression. `length` is a + # special case, since for now we require that all dimensions are + # compile time constants. + + if self.funname.get_name() == "length": + vardef = self.search_vardef_in_scope(self.args[0].get_name()) + if vardef is None: + raise ValueError("No definition found for " + + self.args[0].get_name()) + dims = vardef.get_dims() + length = max(dims) + return AST_Constant(None, length) + + if not self.funname.get_name() in ["zeros", "ones", "eye"]: + return None + + for arg in self.args: + if not arg.is_constant(): + return None + + # The args to those functions can be supplied as a 1x2 matrix or + # two seperate values, the semantics are the same. 
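+        # E.g., zeros(2, 3) and zeros([2, 3]) both denote a 2x3 matrix of
+        # zeros.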
+ dims = [] + if isinstance(self.args, AST_Matrix): + dims = self.args.get_values_row_major() + else: + dims = [x.get_value() for x in self.args] + + rows = [] + for r in range(0, dims[0]): + rowelems = [] + for c in range(0, dims[1]): + zero = AST_Constant(self.context, 0) + one = AST_Constant(self.context, 1) + if self.funname.get_name() == "zeros": + rowelems.append(zero) + if self.funname.get_name() == "ones": + rowelems.append(one) + if self.funname.get_name() == "eye": + if r == c: + rowelems.append(one) + else: + rowelems.append(zero) + rows.append(AST_Matrix_Row(self.context, rowelems)) + res = AST_Matrix(self.context, rows) + res.provide_parents(self.get_parent()) + res.next = self.next + res.prev = self.prev + return res + + __str__ = __repr__ + + +class AST_FunCall(AST_Node): + # NOTE: When parsing, array references, i.e., A(1,2) is the same as + # function calls, so after parsing this node will be used for both, + # and we resolve this later. + def __init__(self, context, funname, args): + AST_Node.__init__(self, context) + self.funname = funname + self.args = args + + def get_children(self): + retval = self.args[:] + retval.append(self.funname) + return retval + + def replace_child(self, old, new): + if old == self.funname: + self.funname = new + return + if old in self.args: + newargs = [new if x == old else x for x in self.args] + self.args = newargs + + def __repr__(self): + return "AST_FunCall(" + str(self.funname) + ", " + str(self.args) + ")" + + def specialize(self): + # This function will be called after we have the complete AST. + # Thus we know if this is a real function call or an array access. + # If it is a function call, differentiate between built-in functions + # and user-defined ones. + from .ast_arrayaccess import AST_ArrayAccess + + if self.funname.get_name() in [ + "zeros", "eye", "rand", "ones", "length", "sqrt" + ]: + new = AST_BuiltInFunCall(self.context, self.funname, self.args) + new.next = self.next + new.prev = self.prev + new.parent = self.parent + for c in new.get_children(): + c.provide_parents(new) + return new + else: + # find the definition of self.funname, if it is anything else + # than an AST_Function this is an array subaccess + vardef = self.search_vardef_in_scope(self.funname.get_name()) + if vardef == None: + raise ValueError("No definition found for " + + self.funname.get_name() + " searching from " + + str(self)) + if isinstance(vardef, AST_Function): + return None + else: + new = AST_ArrayAccess(self.context, self.funname, self.args) + new.next = self.next + new.prev = self.prev + new.parent = self.parent + for c in new.get_children(): + c.provide_parents(new) + return new + return None + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_loop.py b/dace/frontend/octave/ast_loop.py new file mode 100644 index 0000000000..f5be68b47d --- /dev/null +++ b/dace/frontend/octave/ast_loop.py @@ -0,0 +1,202 @@ +import dace + +from .ast_node import AST_Node + + +class AST_ForLoop(AST_Node): + def __init__(self, context, var, initializer, stmts): + AST_Node.__init__(self, context) + self.var = var + self.initializer = initializer + self.stmts = stmts + + def __repr__(self): + return "AST_ForLoop(" + str(self.var) + " = " + str( + self.initializer) + ", stmts: {\n" + str(self.stmts) + "\n})" + + def get_children(self): + return [self.var, self.initializer, self.stmts] + + def replace_child(self, old, new): + if old == self.var: + self.var = new + return + if old == self.initializer: + self.initializer = new + return + if old == self.stmts: 
+ self.stmts = new + return + raise ValueError("The child " + str(old) + " is not a child of " + + str(self)) + + def generate_code(self, sdfg, state): + from .ast_range import AST_RangeExpression + # This ignores matlab semantics and only works for loops of the form + # for var = start:end where start and end are expressions which + # evaluate to scalars. + if isinstance(self.initializer, AST_RangeExpression): + # Generate the initializer: + # lhs and rhs of the iteration range as two transients, and a + # transient for i (we also have a symbol for i which states will + # use) + initializer_state_num = state + s = sdfg.nodes()[state] + self.initializer.lhs.generate_code(sdfg, state) + lhs_node = self.initializer.lhs.get_datanode(sdfg, state) + self.initializer.rhs.generate_code(sdfg, state) + rhs_node = self.initializer.rhs.get_datanode(sdfg, state) + sdfg.add_transient( + self.var.get_name_in_sdfg(sdfg), [1], + self.initializer.lhs.get_basetype()) + var_node = s.add_access(self.var.get_name_in_sdfg(sdfg)) + s.add_edge( + lhs_node, None, var_node, None, + dace.memlet.Memlet.from_array(var_node.data, + var_node.desc(sdfg))) + loop_guard_var = '_loopiter_' + str(state) + loop_end_var = '_loopend_' + str(state) + + # Generate guard state, write loop iter symbol into loop iter + # datanode + guard_state_num = initializer_state_num + 1 + s_guard = sdfg.add_state('s' + str(guard_state_num)) + task = s_guard.add_tasklet('reinitloopiter', {}, {'out'}, + "out=" + loop_guard_var) + + if self.var.get_name_in_sdfg(sdfg) not in sdfg.arrays: + sdfg.add_transient( + self.var.get_name_in_sdfg(sdfg), [1], + self.initializer.lhs.get_basetype()) + trans = s_guard.add_access(self.var.get_name_in_sdfg(sdfg)) + # Workaround until "condition for putting a variable as top-level + # doesn't take inter-state edges into account" is solved. + # When fixed, the line below can be removed. 
+ self.initializer.rhs.generate_code(sdfg, guard_state_num) + + s_guard.add_edge( + task, 'out', trans, None, + dace.memlet.Memlet.from_array(trans.data, trans.desc(sdfg))) + lg_init = dace.graph.edges.InterstateEdge( + assignments={ + loop_guard_var: + self.var.get_name_in_sdfg(sdfg) + '[0]', + loop_end_var: + self.initializer.rhs.get_name_in_sdfg(sdfg) + '[0]' + }) + sdfg.add_edge(sdfg.nodes()[state], s_guard, lg_init) + + # Add state for each statement within the for loop + prev = s_guard + for s in self.stmts.statements: + state = len(sdfg.nodes()) + newstate = dace.SDFGState( + "s" + str(state), sdfg, debuginfo=s.context) + sdfg.add_node(newstate) + last_state = s.generate_code(sdfg, state) + if last_state is None: last_state = state + if prev != s_guard: + edge = dace.graph.edges.InterstateEdge() + sdfg.add_edge(prev, newstate, edge) + else: + edge = dace.graph.edges.InterstateEdge( + condition=dace.properties.CodeProperty.from_string( + loop_guard_var + " <= " + loop_end_var, + language=dace.types.Language.Python)) + sdfg.add_edge(prev, newstate, edge) + prev = sdfg.nodes()[last_state] + + # Create inter-state back-edge + edge = dace.graph.edges.InterstateEdge( + assignments={loop_guard_var: loop_guard_var + '+1'}) + sdfg.add_edge(prev, s_guard, edge) + + # Create the loop exit state + state = len(sdfg.nodes()) + s_lexit = dace.SDFGState( + "s" + str(state), sdfg, debuginfo=s.context) + lend_val = str(self.initializer.get_dims()[-1]) + for_exit = dace.graph.edges.InterstateEdge( + condition=dace.properties.CodeProperty.from_string( + loop_guard_var + " > " + loop_end_var, + language=dace.types.Language.Python)) + sdfg.add_edge(s_guard, s_lexit, for_exit) + + return state + + else: + raise NotImplementedError( + "Loops over anything but ranges are not implemented.") + + def generate_code_proper(self, sdfg, state): + # This follows matlab semantics, i.e., a loop iterates over the columns + # of a matrix. This does not work well for sdfgs for all but the + # simplest case (a matrix which is a compile time constant, ie. 1:10). + # To support programs like Cholesky, we try to transform the matlab for + # loop into a C-style loop, this is implemented in generate_code(). 
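+
+        # The generated state machine has this shape (illustrative):
+        #   init -> guard -> getinit -> body states -> (back edge to guard,
+        #   incrementing the loop counter); guard -> exit once the counter
+        #   passes the end of the range.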
+ + # Generate the initializer: + # Each iteration of the for loop will use one column + initializer_state_num = state + self.initializer.generate_code(sdfg, state) + loop_guard_var = '_lg_' + str(state) + # Generate an (empty) guard state + guard_state_num = initializer_state_num + 1 + s_guard = sdfg.add_state('s' + str(guard_state_num)) + lg_init = dace.graph.edges.InterstateEdge( + assignments={loop_guard_var: '0'}) + sdfg.add_edge(sdfg.nodes()[state], s_guard, lg_init) + + # Read a column of the initializer + get_initializer_state_num = guard_state_num + 1 + s_getinit = sdfg.add_state('s' + str(get_initializer_state_num)) + initializer_name = self.initializer.get_name_in_sdfg(sdfg) + loopvar_name = self.var.get_name_in_sdfg(sdfg) + dims = self.initializer.get_dims()[:1] + sdfg.add_transient(loopvar_name, dims, self.initializer.get_basetype()) + part = s_getinit.add_access(loopvar_name) + sdfg.add_transient(initializer_name, self.initializer.get_dims(), + self.initializer.get_basetype()) + full = s_getinit.add_read(initializer_name) + s_getinit.add_edge(full, None, part, None, + dace.memlet.Memlet.simple(initializer_name, 'i,0')) + + # Add edge from guard to getinit + lend_val = str(self.initializer.get_dims()[-1]) + for_entry = dace.graph.edges.InterstateEdge( + condition=dace.properties.CodeProperty.from_string( + loop_guard_var + " < " + lend_val, + language=dace.types.Language.Python)) + sdfg.add_edge(s_guard, s_getinit, for_entry) + + # Add state for each statement within the for loop + prev = s_getinit + for s in self.stmts.statements: + state = len(sdfg.nodes()) + newstate = dace.SDFGState( + "s" + str(state), sdfg, debuginfo=s.context) + sdfg.add_node(newstate) + last_state = s.generate_code(sdfg, state) + if last_state is None: last_state = state + edge = dace.graph.edges.InterstateEdge() + sdfg.add_edge(prev, newstate, edge) + prev = sdfg.nodes()[last_state] + + # Create inter-state back-edge + edge = dace.graph.edges.InterstateEdge( + assignments={loop_guard_var: loop_guard_var + '+1'}) + sdfg.add_edge(prev, s_guard, edge) + + # Create the loop exit state + state = len(sdfg.nodes()) + s_lexit = dace.SDFGState("s" + str(state), sdfg, debuginfo=s.context) + lend_val = str(self.initializer.get_dims()[-1]) + for_exit = dace.graph.edges.InterstateEdge( + condition=dace.properties.CodeProperty.from_string( + loop_guard_var + " >= " + lend_val, + language=dace.types.Language.Python)) + sdfg.add_edge(s_guard, s_lexit, for_exit) + + return state + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_matrix.py b/dace/frontend/octave/ast_matrix.py new file mode 100644 index 0000000000..1fc3c656d5 --- /dev/null +++ b/dace/frontend/octave/ast_matrix.py @@ -0,0 +1,214 @@ +from .ast_node import AST_Node +from .ast_values import AST_Constant + +import dace + + +class AST_Matrix_Row(AST_Node): + def __init__(self, context, elements): + AST_Node.__init__(self, context) + self.elements = elements + if not isinstance(self.elements, list): + raise ValueError( + "AST_Matrix_Row() expects a list of elements, got " + + str(type(self.elements))) + + def provide_parents(self, parent): + self.parent = parent + for e in self.elements: + e.provide_parents(self) + + def __repr__(self): + return "AST_MatrixRow(" + ", ".join([str(i) + for i in self.elements]) + ")" + + def get_dims(self): + return len(self.elements) + + def get_children(self): + return self.elements[:] + + def replace_child(self, old, new): + newelems = [new if x == old else x for x in self.elements] + self.elements = newelems + + def 
is_constant(self): + for r in self.elements: + if not isinstance(r, AST_Constant): + return False + return True + + def __getitem__(self, item): + if item >= len(self): + raise IndexError("AST_Matrix_Row index out of range") + return self.elements[item] + + def __len__(self): + return len(self.elements) + + __str__ = __repr__ + + +class AST_Matrix(AST_Node): + def __init__(self, context, rows): + AST_Node.__init__(self, context) + self.rows = rows + self.children = self.rows + if not isinstance(self.rows, list): + raise ValueError("AST_Matrix() expects a list of rows, got " + + str(type(self.rows))) + for r in self.rows: + if not isinstance(r, AST_Matrix_Row): + raise ValueError("AST_Matrix() expects a list of rows, got " + + str(r) + " of type " + str(type(r))) + + def __repr__(self): + return "AST_Matrix(" + ", ".join([str(i) for i in self.rows]) + ")" + + def provide_parents(self, parent): + self.parent = parent + for e in self.rows: + e.provide_parents(self) + + def get_dims(self): + dims = -1 + for r in self.rows: + if (dims > 0) and (r.get_dims() != dims): + raise ValueError( + "Matrices with unequal row lengths are currently not " + "supported.") + else: + dims = r.get_dims() + return [len(self.rows), dims] + + def get_basetype(self): + # This should be double, unless we have a complex inside, for now just + # return double. + return dace.types.float64 + + def is_constant(self): + for r in self.rows: + if not r.is_constant(): + return False + return True + + def get_values_row_major(self): + values = [] + for r in self.rows: + for c in r: + if isinstance(c, AST_Constant): + values.append(c.get_value()) + else: + values.append(0) + return values + + def generate_code(self, sdfg, state): + if self.is_constant(): + name = self.get_name_in_sdfg(sdfg) + dims = self.get_dims() + basetype = self.get_basetype() + sdfg.add_transient(name, dims, basetype) + trans = sdfg.nodes()[state].add_access(name) + # Add map over dims, and a taklet which puts the values into the + # transient. 
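+            # For the constant matrix [1 2; 3 4], the generated tasklet body
+            # is (illustrative):
+            #     constexpr double VALUES[4] = {1, 2, 3, 4};
+            #     out[i] = VALUES[i];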
+ arrlen = 1 + for d in dims: + arrlen *= d + vals = self.get_values_row_major() + code = "constexpr double VALUES[" + str(arrlen) + "] = {" + code += ", ".join(str(i) for i in vals) + "};\n" + code += "out[i] = VALUES[i];" + + tasklet = sdfg.nodes()[state].add_tasklet('init', {}, {'out'}, + code, dace.Language.CPP) + me, mx = sdfg.nodes()[state].add_map( + 'init', dict(i='0:' + str(arrlen))) + sdfg.nodes()[state].add_edge(me, None, tasklet, None, + dace.memlet.EmptyMemlet()) + sdfg.nodes()[state].add_edge( + tasklet, "out", mx, None, + dace.memlet.Memlet.from_array(trans.data, trans.desc(sdfg))) + sdfg.nodes()[state].add_edge( + mx, None, trans, None, + dace.memlet.Memlet.from_array(trans.data, trans.desc(sdfg))) + + print("The const expr " + str(self) + " will be stored in " + + str(name) + ", values are: " + + str(self.get_values_row_major())) + else: + raise ValueError( + "Non-constant matrices are currently not supported") + + def get_children(self): + return self.rows[:] + + def replace_child(self, old, new): + newrows = [new if x == old else x for x in self.rows] + self.rows = newrows + + __str__ = __repr__ + + +class AST_Transpose(AST_Node): + def __init__(self, context, arg, op): + AST_Node.__init__(self, context) + self.arg = arg + self.op = op + + def __repr__(self): + return "AST_Transpose(" + str(self.arg) + ", " + str(self.op) + ")" + + def get_children(self): + return [self.arg] + + def get_dims(self): + dims = self.arg.get_dims() + return dims[::-1] + + def get_basetype(self): + return self.arg.get_basetype() + + def generate_code(self, sdfg, state): + dims = self.get_dims() + name = self.get_name_in_sdfg(sdfg) + basetype = self.get_basetype() + if basetype.is_complex(): + raise NotImplementedError( + "Transpose of complex matrices not implemented (we might need " + "to conjugate)") + if len(dims) != 2: + raise NotImplementedError( + "Transpose only implemented for 2D matrices") + sdfg.add_transient(name, dims, basetype, debuginfo=self.context) + + resnode = self.get_datanode(sdfg, state) + self.arg.generate_code(sdfg, state) + A = self.arg.get_datanode(sdfg, state) + + N = str(dims[0]) + M = str(dims[1]) + s = sdfg.nodes()[state] + map_entry, map_exit = s.add_map('transpose', + dict(i='0:' + N, j='0:' + M)) + map_entry._in_connectors.add('IN_1') + map_entry._out_connectors.add('OUT_1') + s.add_edge(A, None, map_entry, 'IN_1', + dace.memlet.Memlet.simple(A, '0:' + N + ',0:' + M)) + tasklet = s.add_tasklet('identity', {'a'}, {'out'}, 'out = a') + s.add_edge(map_entry, "OUT_1", tasklet, "a", + dace.memlet.Memlet.simple(A, 'i,j')) + s.add_edge(tasklet, "out", map_exit, None, + dace.memlet.Memlet.simple(resnode, 'j,i')) + s.add_edge(map_exit, None, resnode, None, + dace.memlet.Memlet.simple(resnode, '0:' + M + ', 0:' + N)) + print("The result of expr " + str(self) + " will be stored in " + + str(name)) + + def replace_child(self, old, new): + if old == self.arg: + self.arg = new + return + raise ValueError("The child " + str(old) + " is not a child of " + + str(self)) + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_node.py b/dace/frontend/octave/ast_node.py new file mode 100644 index 0000000000..62ba8d69e9 --- /dev/null +++ b/dace/frontend/octave/ast_node.py @@ -0,0 +1,307 @@ +import re +import dace +from collections import OrderedDict + + +class AST_Node(): + def __init__(self, context): + self.context = context + self.name = None # Name of the variable holding the result in the SDFG + self.parent = None + self.next = None + self.prev = None + self.initializers = 
{}

+    def get_parent(self):
+        return self.parent
+
+    def replace_parent(self, newparent):
+        self.parent = newparent
+
+    def get_children(self):
+        raise NotImplementedError(
+            str(type(self)) + " does not implement get_children()")
+
+    def replace_child(self, old, new):
+        raise NotImplementedError(
+            str(type(self)) + " does not implement replace_child()")
+
+    def specialize(self):
+        """ Some nodes can be simplified after parsing the complete AST and
+            before actually generating code, i.e., AST_FunCall nodes could be
+            function calls or array accesses, and we don't really know unless
+            we know the context of the call.
+
+            This function traverses the AST and tries to specialize nodes
+            after completing the AST. It should be called on the top-level
+            AST_Statements node, and a node that wants to be specialized
+            should return its new instance. If no specialization should take
+            place, it should return None.
+        """
+        for c in self.get_children():
+            n = c.specialize()
+            while n is not None:
+                n.replace_parent(c.get_parent())
+                self.replace_child(old=c, new=n)
+                c = n
+                n = n.specialize()
+
+    def find_data_node_in_sdfg_state(self, sdfg, state, nodename=None):
+        if nodename is None:
+            nodename = self.get_name_in_sdfg(sdfg)
+        sdfg_state = sdfg.nodes()[state]
+        for node in sdfg_state.nodes():
+            if isinstance(node, dace.graph.nodes.AccessNode):
+                if node.label == nodename:
+                    return node
+
+        raise ValueError("No AccessNode with name " + nodename + " found.")
+
+    def get_initializers(self, sdfg):
+        initializers = self.initializers
+        for c in self.get_children():
+            initializers.update(c.get_initializers(sdfg))
+        return initializers
+
+    def provide_parents(self, parent):
+        self.parent = parent
+        for c in self.get_children():
+            c.provide_parents(self)
+
+    def search_vardef_in_scope(self, name):
+        from .ast_assign import AST_Assign
+        from .ast_values import AST_Ident
+        from .ast_loop import AST_ForLoop
+        current_node = self
+
+        # check if we found the definition:
+        # * current_node is an AST_Assign with name as lhs or
+        # * a loop with name as the iterator
+        if isinstance(current_node, AST_Assign) and \
+                isinstance(current_node.lhs, AST_Ident) and \
+                (current_node.lhs.get_name() == name):
+            return current_node.rhs
+        elif isinstance(current_node, AST_ForLoop) and \
+                current_node.var.get_name() == name:
+            return current_node
+
+        # if the current node is inside a list of stmts, traverse this list
+        # using prev, but first find the enclosing AST_Statements
+        while current_node.get_parent() is not None:
+            old_current_node = current_node
+            if isinstance(current_node.get_parent(), AST_Statements):
+                while current_node.prev is not None:
+                    res = current_node.prev.search_vardef_in_scope(name)
+                    if res is not None:
+                        return res
+                    current_node = current_node.prev
+            current_node = current_node.get_parent()
+            res = current_node.search_vardef_in_scope(name)
+            if res is not None:
+                return res
+
+        return None
+
+    def defined_variables(self):
+        # Override this to return the string names of variables defined by an
+        # AST_Node
+        return []
+
+    def get_datanode(self, sdfg, state):
+        try:
+            result = self.find_data_node_in_sdfg_state(
+                sdfg=sdfg,
+                state=state,
+                nodename=self.get_name_in_sdfg(sdfg=sdfg))
+        except ValueError:
+            result = sdfg.nodes()[state].add_access(
+                self.get_name_in_sdfg(sdfg=sdfg))
+        return result
+
+    def get_new_tmpvar(self, sdfg):
+        TEMPVARS_PREFIX = "__tmp_"
+        maxvar = 0
+        for state in range(0, len(sdfg.nodes())):
+            sdfg_state = sdfg.nodes()[state]
+            for node in sdfg_state.nodes():
+                if isinstance(node, dace.graph.nodes.AccessNode):
+                    m = re.match(TEMPVARS_PREFIX + r"(\d+)", node.label)
+                    if m is not None:
+                        if maxvar < int(m.group(1)):
+                            maxvar = int(m.group(1))
+        newvar = maxvar + 1
+        new_name = TEMPVARS_PREFIX + str(newvar)
+        return new_name
+
+    def get_name_in_sdfg(self, sdfg):
+        """ If this node has no name assigned yet, create a new one of the
+            form `__tmp_X` where `X` is an integer, such that this node does
+            not yet exist in the given SDFG.
+            @note: We assume that we create exactly one SDFG from each AST,
+                   otherwise we need to store the hash of the SDFG the name
+                   was created for (would be easy but seems useless at this
+                   point).
+        """
+        if self.name is not None:
+            return self.name
+        self.name = self.get_new_tmpvar(sdfg)
+        return self.name
+
+    def generate_code(self, *args):
+        raise NotImplementedError("Class " + type(
+            self).__name__ + " does not implement the generate_code method.")
+
+    def shortdesc(self):
+        ret = str(self)
+        ret = re.sub(r"\n", " ; ", ret)
+        return "\"" + ret[0:70] + "\""
+
+    def print_as_tree(self):
+        ret = ""
+        ret += self.shortdesc() + ";\n"
+        for c in self.get_children():
+            ret += self.shortdesc() + "->" + c.shortdesc(
+            ) + "[label=\"child\", color=\"red\"] ;\n"
+            ret += c.print_as_tree()
+
+        if self.get_parent() is None:
+            ret += self.shortdesc(
+            ) + " -> \"None\" [label=\"parent\", color=\"blue\"];\n"
+        else:
+            ret += self.shortdesc() + " -> " + self.get_parent().shortdesc(
+            ) + "[label=\"parent\", color=\"blue\"];\n"
+
+        if isinstance(self, AST_Statements):
+            ret += "{ rank=same; "
+            for c in self.get_children():
+                ret += c.shortdesc() + "; "
+            ret += "}\n"
+            for c in self.get_children():
+                if c.next is not None:
+                    ret += c.shortdesc() + " -> " + c.next.shortdesc(
+                    ) + "[label=\"next\", color=\"green\"]"
+                if c.prev is not None:
+                    ret += c.shortdesc() + " -> " + c.prev.shortdesc(
+                    ) + "[label=\"prev\", color=\"yellow\"]"
+
+        return ret
+
+
+class AST_Statements(AST_Node):
+    def __init__(self, context, stmts):
+        AST_Node.__init__(self, context)
+        self.statements = stmts
+
+        # we expect stmts to be a list of AST_Node objects
+        for s in stmts:
+            if not isinstance(s, AST_Node):
+                raise ValueError(
+                    "Expected a list of AST_Nodes, but one of the members is: "
+                    + str(s) + " type " + str(type(s)))
+
+    def __repr__(self):
+        res = ["Statements:"]
+        for s in self.statements:
+            res.append("  " + str(s))
+        return "\n".join(res)
+
+    def get_children(self):
+        return self.statements[:]
+
+    def replace_child(self, old, new):
+        newstmts = [new if x == old else x for x in self.statements]
+        self.statements = newstmts
+        self.provide_parents(self.get_parent())
+
+    def append_statement(self, stmt):
+        if isinstance(stmt, list):
+            self.statements += stmt
+        else:
+            self.statements.append(stmt)
+
+    def provide_parents(self, parent=None):
+        # Overwrite the AST_Node provide_parents() function
+        # because we also set next and prev for statements, which
+        # should be null for most / all AST_Nodes
+        self.parent = parent
+
+        # fix prev
+        prev = None
+        for s in self.statements:
+            s.prev = prev
+            prev = s
+
+        # fix next
+        next = None
+        for s in reversed(self.statements):
+            s.next = next
+            next = s
+
+        for s in self.statements:
+            s.provide_parents(parent=self)
+
+    def specialize(self):
+        # If we have an AST_Function() node, pull all statements between that
+        # and the next AST_EndFunction() into the function. Do that until
+        # there are no more changes.
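+        # For example (hypothetical Octave input), the statement list
+        #     function y = f(x)    <- AST_Function
+        #     y = x + 1;           <- pulled into the body of f
+        #     endfunction          <- AST_EndFunc
+        # collapses into a single AST_Function node holding "y = x + 1;".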
+ rerun = True + while rerun: + rerun = False + stmts = None + func = None + for c in self.get_children(): + from .ast_function import AST_Function, AST_EndFunc + if isinstance(c, AST_Function): + func = c + stmts = [] + elif isinstance(c, AST_EndFunc): + func.set_statements(stmts) + self.statements = [ + x for x in self.statements if x not in stmts + [c] + ] + rerun = True + elif func is not None: + stmts.append(c) + + # Remove NullStatements, they are only useful during parsing + from .ast_nullstmt import AST_NullStmt + self.statements = [ + x for x in self.statements if not isinstance(x, AST_NullStmt) + ] + self.provide_parents(self.parent) + + # Lastly, specialize all children + for c in self.get_children(): + n = c.specialize() + while n is not None: + n.replace_parent(c.get_parent()) + self.replace_child(old=c, new=n) + c = n + n = n.specialize() + + self.provide_parents(self.parent) + + return None + + def generate_code(self, sdfg=None, state=None): + if sdfg is None: + sdfg = dace.SDFG("dacelab", OrderedDict(), {}) + prevstate = None + for s in self.statements: + state = len(sdfg.nodes()) + newstate = dace.SDFGState( + "s" + str(state), sdfg, debuginfo=s.context) + sdfg.add_node(newstate) + last_state = s.generate_code(sdfg, state) + if prevstate is not None: + edge = dace.graph.edges.InterstateEdge() + sdfg.add_edge(prevstate, newstate, edge) + if last_state is None: + prevstate = newstate + else: + prevstate = sdfg.nodes()[last_state] + + return sdfg + else: + raise ValueError( + "Appending statements to an SDFG is not supported.") + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_nullstmt.py b/dace/frontend/octave/ast_nullstmt.py new file mode 100644 index 0000000000..2b94d0d9ed --- /dev/null +++ b/dace/frontend/octave/ast_nullstmt.py @@ -0,0 +1,52 @@ +from .ast_node import AST_Node + + +class AST_NullStmt(AST_Node): + def __init__(self, context): + AST_Node.__init__(self, context) + + def get_children(self): + return [] + + def replace_child(self, old, new): + raise ValueError("AST_NullStmt has no children") + + def generate_code(self, sdfg, state): + pass + + def __repr__(self): + return "AST_NullStmt()" + + +class AST_EndStmt(AST_Node): + def __init__(self, context): + AST_Node.__init__(self, context) + + def __repr__(self): + return "AST_End()" + + def get_children(self): + return [] + + def replace_child(self, old, new): + raise ValueError("This class does not have children") + + +class AST_Comment(AST_Node): + def __init__(self, context, text): + AST_Node.__init__(self, context) + self.text = text + + def get_children(self): + return [] + + def replace_child(self, old, new): + raise ValueError("AST_Comment has no children") + + def generate_code(self, sdfg, state): + pass + + def __repr__(self): + text = self.text + text = text.encode("unicode_escape").decode("utf-8") + return "AST_Comment(\"" + text + "\")" diff --git a/dace/frontend/octave/ast_range.py b/dace/frontend/octave/ast_range.py new file mode 100644 index 0000000000..00b7c85c04 --- /dev/null +++ b/dace/frontend/octave/ast_range.py @@ -0,0 +1,69 @@ +import dace + +from .ast_node import AST_Node + + +class AST_RangeExpression(AST_Node): + def __init__(self, context, lhs, rhs): + AST_Node.__init__(self, context) + self.lhs = lhs + self.rhs = rhs + + def __repr__(self): + return "AST_RangeExpression(" + str(self.lhs) + ", " + str( + self.rhs) + ")" + + def get_children(self): + L = [self.lhs, self.rhs] + return [x for x in L if x is not None] + + def get_dims(self): + from .ast_values import AST_Constant 
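+        # A range with constant bounds, e.g., 3:7, is a 1-by-5 row vector;
+        # for non-constant bounds we fall back to [1, 1] below, after
+        # printing a notice that the dimensionality could not be inferred.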
+ if isinstance(self.lhs, AST_Constant) and isinstance( + self.rhs, AST_Constant): + l = self.rhs.get_value() - self.lhs.get_value() + 1 + return [1, l] + else: + print("Dimensionality of " + str(self) + " cannot be inferred") + return [1, 1] + + def get_basetype(self): + return dace.types.float64 + + def replace_child(self, old, new): + if old == self.lhs: + self.lhs = new + return + if old == self.rhs: + self.rhs = new + return + raise ValueError("The child " + str(old) + " is not a child of " + + str(self)) + + def specialize(self): + return None + + def generate_code(self, sdfg, state): + # If lhs and rhs are constant, generate a matrix + from .ast_values import AST_Constant + from .ast_matrix import AST_Matrix_Row, AST_Matrix + if isinstance(self.lhs, AST_Constant) and isinstance( + self.rhs, AST_Constant): + lval = self.lhs.get_value() + rval = self.rhs.get_value() + vals = [ + AST_Constant(self.context, v) + for v in list(range(lval, rval + 1)) + ] + new = AST_Matrix(self.context, + [AST_Matrix_Row(self.context, vals)]) + new.parent = self.parent + new.prev = self.prev + new.next = self.next + new.generate_code(sdfg, state) + else: + raise NotImplementedError( + "Code generation for Range with non-constant bounds not " + "implemented") + + __str__ = __repr__ diff --git a/dace/frontend/octave/ast_values.py b/dace/frontend/octave/ast_values.py new file mode 100644 index 0000000000..0831409eb4 --- /dev/null +++ b/dace/frontend/octave/ast_values.py @@ -0,0 +1,115 @@ +import dace + +from .ast_node import AST_Node + + +class AST_Ident(AST_Node): + def __init__(self, context, value): + AST_Node.__init__(self, context) + if isinstance(value, str): + self.value = value + else: + raise ValueError("Expected str, got " + str(type(value))) + + def __repr__(self): + return "AST_Ident(" + str(self.value) + ")" + + def get_name(self): + return self.value + + def is_constant(self): + return False + + def get_name_in_sdfg(self, sdfg): + return self.value + + def get_children(self): + return [] + + def replace_child(self, old, new): + raise ValueError("This node does not have children!") + + def generate_code(self, sdfg, state): + # An identifier never generates code + pass + + def get_dims(self): + from .ast_loop import AST_ForLoop + """ Check in the scope if this is defined and return the dims of the + corresponding SDFG access node it currently maps to. """ + vardef = self.search_vardef_in_scope(self.value) + if vardef is None: + raise ValueError("Request for dims of identifier " + self.value + + " which is not defined in the current scope") + elif isinstance(vardef, AST_ForLoop): + dims = vardef.initializer.get_dims()[:1] + return dims + else: + return vardef.get_dims() + + def specialize(self): + pass + + def get_propagated_value(self): + vardef = self.search_vardef_in_scope(self.get_name()) + if isinstance(vardef, AST_Constant): + return vardef + return None + + def get_basetype(self): + """ Check in the scope if this is defined and return the basetype of the + corresponding SDFG access node this currently maps to. 
""" + bt = self.search_vardef_in_scope(self.value).get_basetype() + if bt is None: + raise ValueError("Request for basetype of identifier " + + self.value + + " which is not defined in the current scope") + else: + return bt + + __str__ = __repr__ + + +class AST_Constant(AST_Node): + def __init__(self, context, value): + AST_Node.__init__(self, context) + self.value = value + + def __repr__(self): + return "AST_Constant(" + str(self.value) + ")" + + def get_value(self): + return self.value + + def get_dims(self): + return [1] + + def get_basetype(self): + return dace.types.float64 + + def generate_code(self, sdfg, state): + dims = self.get_dims() + name = self.get_name_in_sdfg(sdfg) + basetype = dace.types.float64 + if name not in sdfg.arrays: + sdfg.add_transient(name, dims, basetype, debuginfo=self.context) + trans = sdfg.nodes()[state].add_access(name) + code = "out = " + str(self.get_value()) + ";" + tasklet = sdfg.nodes()[state].add_tasklet('init', {}, {'out'}, code, + dace.Language.CPP) + sdfg.nodes()[state].add_edge( + tasklet, 'out', trans, None, + dace.memlet.Memlet.from_array(trans.data, trans.desc(sdfg))) + print("The result of expr " + str(self) + " will be stored in " + + str(name)) + + def get_children(self): + return [] + + def is_constant(self): + return True + + def replace_child(self, old, new): + raise ValueError("This node does not have children!") + + __str__ = __repr__ diff --git a/dace/frontend/octave/lexer.py b/dace/frontend/octave/lexer.py new file mode 100644 index 0000000000..6c9297cfd8 --- /dev/null +++ b/dace/frontend/octave/lexer.py @@ -0,0 +1,353 @@ +import sys +import re +import ply.lex as lex +from ply.lex import TOKEN + +tokens = [ + "AND", "ANDAND", "ANDEQ", "BACKSLASH", "COLON", "COMMA", "DIV", "DIVEQ", + "DOT", "DOTDIV", "DOTDIVEQ", "DOTEXP", "DOTMUL", "DOTMULEQ", "END_EXPR", + "END_STMT", "EQ", "EQEQ", "EXP", "EXPEQ", "FIELD", "GE", "GT", "HANDLE", + "IDENT", "LBRACE", "LBRACKET", "LE", "LPAREN", "LT", "MINUS", "MINUSMINUS", + "MINUSEQ", "MUL", "MULEQ", "NE", "NEG", "NUMBER", "OR", "OREQ", "OROR", + "PLUS", "PLUSEQ", "PLUSPLUS", "RBRACE", "RBRACKET", "RPAREN", "SEMI", + "STRING", "TRANSPOSE", "ERROR_STMT", "COMMENT", "END_FUNCTION", + "END_UNEXPECTED", "POW", "CLASSDEF" +] + +reserved = { + "break": "BREAK", + "case": "CASE", + "catch": "CATCH", + "continue": "CONTINUE", + "else": "ELSE", + "elseif": "ELSEIF", + "end_unwind_protect": "END_UNWIND_PROTECT", + "for": "FOR", + "function": "FUNCTION", + "global": "GLOBAL", + "if": "IF", + "otherwise": "OTHERWISE", + "persistent": "PERSISTENT", + "return": "RETURN", + "switch": "SWITCH", + "try": "TRY", + "unwind_protect": "UNWIND_PROTECT", + "unwind_protect_cleanup": "UNWIND_PROTECT_CLEANUP", + "while": "WHILE", +} +tokens += list(reserved.values()) + + +def new(): + t_AND = r"\&" + t_ANDAND = r"\&\&" + t_ANDEQ = r"\&=" + t_BACKSLASH = r"\\" + t_COLON = r":" + t_DIV = r"\/" + t_DIVEQ = r"\/=" + t_DOT = r"\." + t_DOTDIV = r"\./" + t_DOTDIVEQ = r"\./=" + t_DOTEXP = r"\.\^" + t_DOTMUL = r"\.\*" + t_DOTMULEQ = r"\.\*=" + t_EQ = r"=" + t_EQEQ = r"==" + t_EXP = r"\^" + t_EXPEQ = r"\^=" + t_GE = r">=" + t_GT = r"\>" + t_HANDLE = r"\@" + t_LE = r"<=" + t_LT = r"\<" + t_MINUS = r"\-" + t_MINUSEQ = r"\-=" + t_MINUSMINUS = r"\--" + t_MUL = r"\*" + t_POW = r"\*\*" + t_MULEQ = r"\*=" + t_NE = r"(~=)|(!=)" + t_NEG = r"\~|\!" 
+    t_OR = r"\|"
+    t_OREQ = r"\|="
+    t_OROR = r"\|\|"
+    t_PLUS = r"\+"
+    t_PLUSEQ = r"\+="
+    t_PLUSPLUS = r"\+\+"
+
+    states = (("matrix", "inclusive"), ("afterkeyword", "exclusive"))
+
+    ws = r"(\s|\.\.\..*\n|\\\n)"
+    #ws = r"(\s|(\#|(%[^!])).*\n|\.\.\..*\n|\\\n)"
+    ws1 = ws + "+"
+    ws0 = ws + "*"
+    ms = r"'([^']|(''))*'"
+    os = r'"([^"\a\b\f\r\t\0\v\n\\]|(\\[abfn0vtr\"\n\\])|(""))*"'
+    mos = "(%s)|(%s)" % (os, ms)
+    id = r"[a-zA-Z_][a-zA-Z_0-9]*"
+
+    def unescape(s):
+        if s[0] == "'":
+            return s[1:-1].replace("''", "'")
+        else:
+            try:
+                return s[1:-1].decode("string_escape")
+            except:
+                return s[1:-1]
+
+    @TOKEN(mos)
+    def t_afterkeyword_STRING(t):
+        t.value = unescape(t.value)
+        t.lexer.begin("INITIAL")
+        return t
+
+    def t_afterkeyword_error(t):
+        t_error(t)
+
+    # A quote, immediately following any of: (1) an alphanumeric
+    # character, (2) right bracket, parenthesis or brace,
+    # or (3) another TRANSPOSE, is a TRANSPOSE. Otherwise, it starts a
+    # string. The order of the rules for TRANSPOSE (first) and STRING
+    # (second) is important. Luckily, if the quote is separated from
+    # the term by line continuation (...), MATLAB starts a string, so
+    # the above rule still holds.
+
+    def t_TRANSPOSE(t):
+        r"(?<=\w|\]|\)|\})((\.')|')+"
+        # <---context ---><-quotes->
+        # Let the parser figure out what that mix of quotes and
+        # dot-quotes, which is kept in t.value, really means.
+        return t
+
+    @TOKEN(mos)
+    def t_STRING(t):
+        t.value = unescape(t.value)
+        return t
+
+    @TOKEN(r"(\.%s)?%s" % (ws0, id))
+    def t_IDENT(t):
+        if t.value == "parfor":
+            t.value = "for"
+        if t.value == "classdef":
+            raise_exception(SyntaxError, "Not implemented: %s" % t.value,
+                            t.lexer)
+        t.lexer.lineno += t.value.count("\n")
+        if t.value[0] == ".":
+            # Reserved words are not reserved
+            # when used as fields. So return=1
+            # is illegal, but foo.return=1 is fine.
+            t.type = "FIELD"
+            return t
+        if (t.value == "end" and (t.lexer.parens > 0 or t.lexer.brackets > 0
+                                  or t.lexer.braces > 0)):
+            t.type = "END_EXPR"
+            return t
+        if t.value in ("end", "endif", "endfunction", "endwhile", "endfor",
+                       "endswitch", "end_try_catch"):
+            keyword = t.lexer.stack.pop()  # if,while,etc.
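+            # A generic "end" (or "endfor", "endwhile", ...) closes the
+            # innermost open block keyword; only "function" blocks produce
+            # END_FUNCTION, everything else becomes END_STMT.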
+            if keyword == "function":
+                t.type = "END_FUNCTION"
+            else:
+                t.type = "END_STMT"
+            return t
+        else:
+            t.type = reserved.get(t.value, "IDENT")
+            if t.value in ("if", "function", "while", "for", "switch", "try"):
+                # Lexer stack may contain only these
+                # six words, ever, because there is
+                # one place to push -- here
+                t.lexer.stack.append(t.value)
+            if (t.type != "IDENT" and t.lexer.lexdata[t.lexer.lexpos] == "'"):
+                t.lexer.begin("afterkeyword")
+        return t
+
+    def t_LPAREN(t):
+        r"\("
+        t.lexer.parens += 1
+        return t
+
+    def t_RPAREN(t):
+        r"\)"
+        t.lexer.parens -= 1
+        return t
+
+    @TOKEN(ws0 + r"\]")
+    def t_RBRACKET(t):  # compare w t_LBRACKET
+        t.lexer.lineno += t.value.count("\n")
+        t.lexer.brackets -= 1
+        if t.lexer.brackets + t.lexer.braces == 0:
+            t.lexer.begin("INITIAL")
+        return t
+
+    @TOKEN(r"\[" + ws0)
+    def t_LBRACKET(t):  # compare w t_SEMI
+        t.lexer.lineno += t.value.count("\n")
+        t.lexer.brackets += 1
+        if t.lexer.brackets + t.lexer.braces == 1:
+            t.lexer.begin("matrix")
+        return t
+
+    # maybe we need a dedicated CELLARRAY state
+    @TOKEN(ws0 + r"\}")
+    def t_RBRACE(t):
+        t.lexer.lineno += t.value.count("\n")
+        t.lexer.braces -= 1
+        if t.lexer.braces + t.lexer.brackets == 0:
+            t.lexer.begin("INITIAL")
+        return t
+
+    @TOKEN(r"\{" + ws0)
+    def t_LBRACE(t):
+        t.lexer.lineno += t.value.count("\n")
+        t.lexer.braces += 1
+        if t.lexer.brackets + t.lexer.braces == 1:
+            t.lexer.begin("matrix")
+        return t
+
+    @TOKEN(r"," + ws0)
+    def t_COMMA(t):  # eating spaces is important inside brackets
+        t.lexer.lineno += t.value.count("\n")
+        if (t.lexer.brackets == 0 and t.lexer.parens == 0
+                and t.lexer.braces == 0):
+            t.type = "SEMI"
+            return t
+        return t
+
+    @TOKEN(r"\;" + ws0)
+    def t_SEMI(t):
+        t.lexer.lineno += t.value.count("\n")
+        # if t.lexer.brackets or t.lexer.braces > 0:
+        #     t.type = "CONCAT"
+        return t
+
+    def t_NUMBER(t):
+        r"(0x[0-9A-Fa-f]+)|((\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?[ij]?)"
+        # <------------->  <------------------><------------->
+        #   int,oct,hex           float              exp
+        if t.value[-1] == 'i':
+            t.value = t.value[:-1] + 'j'
+        t.value = eval(t.value)
+        return t
+
+    def t_NEWLINE(t):
+        r'\n+'
+        t.lexer.lineno += len(t.value)
+        if not t.lexer.parens and not t.lexer.braces:
+            t.value = ";"
+            t.type = "SEMI"
+            return t
+
+    def t_ERROR_STMT(t):
+        r"%!(error|warning|test).*\n"
+        t.lexer.lineno += 1
+
+    # Keep multiline comments
+    def t_COMMENT(t):
+        r"(^[ \t]*[%#][^!\n].*\n)+"
+        t.lexer.lineno += t.value.count("\n")
+        t.type = "COMMENT"
+        return t
+
+    # Drop end-of-line comments
+    def t_comment(t):
+        r"(%|\#)!?"
+        if t.value[-1] != "!":
+            t.lexer.lexpos = t.lexer.lexdata.find("\n", t.lexer.lexpos)
+
+    @TOKEN(r"(?<=\w)" + ws1 + r"(?=\()")
+    def t_matrix_BAR(t):
+        # Consume whitespace that follows the end of a name and is followed
+        # by a left parenthesis. This properly handles a space between a
+        # function name and its arguments.
+        pass
+
+    tend = r"(?<=[])}'\".]|\w)"
+    tbeg = r"(?=[-+]?([[({'\"]|\w|\.\d))"
+
+    @TOKEN(tend + ws1 + tbeg)
+    def t_matrix_FOO(t):
+        # In matrix state, consume whitespace separating two
+        # terms and return a fake COMMA token. This allows
+        # parsing [1 2 3] as if it was [1,2,3]. Handle
+        # with care: [x + y] vs [x +y]
+        #
+        # A term T is
+        # (a) a name or a number
+        # (b) a literal string using single or double quotes
+        # (c) (T) or [T] or {T} or T' or +T or -T
+        #
+        # Terms end with
+        # (1) an alphanumeric character \w
+        # (2) a single quote (in Octave also a double quote)
+        # (3) a right parenthesis, bracket, or brace
+        # (4) a dot (after a number, such as 3.)
+ # + # The pattern for whitespace accounts for ellipsis as a + # whitespace, and for the trailing whitespace. + # + # Terms start with + # (1) an alphanumeric character + # (2) a single or double quote, + # (3) left parenthesis, bracket, or brace and finally + # (4) a dot before a digit, such as .3 . + + # TODO: What about curly brackets? + # TODO: What about dot followed by a letter, as in field? + # [foo .bar] + + t.lexer.lineno += t.value.count("\n") + t.type = "COMMA" + return t + + def t_ELLIPSIS(t): + r"\.\.\..*\n" + t.lexer.lineno += 1 + pass + + def t_SPACES(t): + r"(\\\n|[ \t\r])+" + pass + + def t_error(t): + raise_exception(SyntaxError, ('Unexpected "%s" (lexer)' % t.value), + t.lexer) + + lexer = lex.lex(reflags=re.MULTILINE) + lexer.brackets = 0 # count open square brackets + lexer.parens = 0 # count open parentheses + lexer.braces = 0 # count open curly braces + lexer.stack = [] + return lexer + + +def raise_exception(error_type, message, my_lexer): + startpos = 1 + my_lexer.lexdata.rfind("\n", 0, my_lexer.lexpos) + endpos = my_lexer.lexdata.find("\n", startpos) + raise error_type( + message, ("inputfile", my_lexer.lineno, 1 + my_lexer.lexpos - startpos, + my_lexer.lexdata[startpos:endpos])) + + +def main(): + lexer = new() + line = "" + while 1: + try: + line += raw_input("=>> ").decode("string_escape") + print(len(line), [c for c in line]) + except EOFError: + reload(sys.modules["lexer.py"]) + lexer.input(line) + print(list(tok for tok in lexer)) + line = "" + + +if __name__ == "__main__": + lexer = new() + buf = open(sys.argv[1]).read() + lexer.input(buf) + for tok in lexer: + print(tok) diff --git a/dace/frontend/octave/parse.py b/dace/frontend/octave/parse.py new file mode 100644 index 0000000000..a5a1ccbaba --- /dev/null +++ b/dace/frontend/octave/parse.py @@ -0,0 +1,689 @@ +import sys +from ply import yacc +from . 
import lexer
+import copy
+import dace
+
+from .ast_node import AST_Node, AST_Statements
+from .ast_values import AST_Ident, AST_Constant
+from .ast_expression import AST_BinExpression, AST_UnaryExpression
+from .ast_matrix import AST_Matrix_Row, AST_Matrix, AST_Transpose
+from .ast_assign import AST_Assign
+from .ast_function import AST_Argument, AST_BuiltInFunCall, AST_FunCall, AST_Function, AST_EndFunc
+from .ast_range import AST_RangeExpression
+from .ast_loop import AST_ForLoop
+from .ast_nullstmt import AST_NullStmt, AST_Comment, AST_EndStmt
+
+tokens = lexer.tokens
+
+precedence = (
+    ("right", "COMMA"),
+    ("right", "DOTDIVEQ", "DOTMULEQ", "EQ", "EXPEQ", "MULEQ", "MINUSEQ",
+     "DIVEQ", "PLUSEQ", "OREQ", "ANDEQ"),
+    ("nonassoc", "HANDLE"),
+    ("left", "COLON"),
+    ("left", "ANDAND", "OROR"),
+    ("left", "EQEQ", "NE", "GE", "LE", "GT", "LT"),
+    ("left", "OR", "AND"),
+    ("left", "PLUS", "MINUS"),
+    ("left", "MUL", "DIV", "DOTMUL", "DOTDIV", "BACKSLASH"),
+    ("right", "UMINUS", "NEG"),
+    ("right", "TRANSPOSE"),
+    ("right", "EXP", "DOTEXP", "POW"),
+    ("nonassoc", "LPAREN", "RPAREN", "RBRACE", "LBRACE"),
+    ("left", "FIELD", "DOT", "PLUSPLUS", "MINUSMINUS"),
+)
+
+
+def p_top(p):
+    """
+    top :
+        | top stmt
+    """
+
+    if len(p) == 1:
+        retval = AST_Statements(None, [])
+        p[0] = retval
+    else:
+        retval = copy.deepcopy(p[1])
+        retval.append_statement(p[2])
+        p[0] = retval
+
+
+def p_end(p):
+    """
+    top : top END_STMT
+    """
+    retval = copy.deepcopy(p[1])
+    retval.append_statement(AST_EndStmt(None))
+    p[0] = retval
+
+
+def p_end_function(p):
+    """
+    top : top END_FUNCTION
+    """
+    retval = copy.deepcopy(p[1])
+    retval.append_statement(AST_EndFunc(None))
+    p[0] = retval
+
+
+def p_arg1(p):
+    """
+    arg1 : IDENT
+    """
+    startl, endl = p.linespan(1)
+    startc, endc = p.lexspan(1)
+    di = dace.types.DebugInfo(startl, startc, endl, endc)
+    p[0] = AST_Ident(di, p[1])
+
+
+def p_arg2(p):
+    """
+    arg1 : NUMBER
+         | STRING
+    """
+    startl, endl = p.linespan(1)
+    startc, endc = p.lexspan(1)
+    di = dace.types.DebugInfo(startl, startc, endl, endc)
+    p[0] = AST_Constant(di, p[1])
+
+
+def p_global(p):
+    """
+    arg1 : GLOBAL
+    """
+    raise NotImplementedError("global not implemented")
+
+
+def p_arg_list(p):
+    """
+    arg_list : ident_init_opt
+             | arg_list COMMA ident_init_opt
+    """
+    if len(p) == 2:
+        p[0] = [p[1]]
+    else:
+        p[0] = p[1] + [p[3]]
+
+
+def p_args(p):
+    """
+    args : arg1
+         | args arg1
+    """
+    raise NotImplementedError("args not implemented")
+
+
+def p_break_stmt(p):
+    """ break_stmt : BREAK SEMI """
+    raise NotImplementedError("break not implemented")
+
+
+def p_case_list(p):
+    """
+    case_list :
+              | CASE expr sep stmt_list_opt case_list
+              | CASE expr error stmt_list_opt case_list
+              | OTHERWISE stmt_list
+    """
+    raise NotImplementedError("case not implemented")
+
+
+def p_cellarray(p):
+    """
+    cellarray : LBRACE RBRACE
+              | LBRACE matrix_row RBRACE
+              | LBRACE matrix_row SEMI RBRACE
+    """
+    startl, endl = p.linespan(0)
+    startc, endc = p.lexspan(0)
+    di = dace.types.DebugInfo(startl, startc, endl, endc)
+
+    if len(p) == 3:
+        p[0] = AST_Matrix(di, [])
+    else:
+        p[0] = AST_Matrix(di, p[2])
+
+
+def p_cellarray_2(p):
+    """
+    cellarray : LBRACE expr_list RBRACE
+    """
+    startl, endl = p.linespan(0)
+    startc, endc = p.lexspan(0)
+    di = dace.types.DebugInfo(startl, startc, endl, endc)
+
+    p[0] = AST_Matrix(di, [AST_Matrix_Row(di, p[2])])
+
+
+def p_cellarrayref(p):
+    """expr : expr LBRACE expr_list RBRACE
+            | expr LBRACE RBRACE
+    """
+    raise NotImplementedError("cellarrayref not implemented")
+
+
+def p_command(p):
+    """
+    command : ident args SEMI
+    """
+    raise NotImplementedError("commands not implemented")
+
+
+####################
+
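+# Note on the pattern shared by the p_* rules in this file: parse() at the
+# bottom of this file invokes the parser with tracking=1, so ply records a
+# line span and a lexer-position span for every grammar symbol, which each
+# rule wraps into a dace.types.DebugInfo object attached to the AST node it
+# builds. A minimal standalone sketch of that ply feature (toy grammar,
+# hypothetical; kept commented out so it cannot interfere with the real
+# grammar defined in this module):
+#
+#     import ply.lex as lex
+#     import ply.yacc as yacc
+#
+#     tokens = ("NAME", )
+#     t_NAME = r"[a-zA-Z_][a-zA-Z_0-9]*"
+#     t_ignore = " "
+#
+#     def t_error(t):
+#         t.lexer.skip(1)
+#
+#     def p_expr(p):
+#         "expr : NAME"
+#         startl, endl = p.linespan(1)  # line span of symbol 1
+#         startc, endc = p.lexspan(1)   # lexer position span of symbol 1
+#         p[0] = (p[1], (startl, startc, endl, endc))
+#
+#     def p_error(p):
+#         raise ValueError("Unexpected EOF")
+#
+#     print(yacc.yacc().parse("x", lexer=lex.lex(), tracking=True))
+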
+ +def p_comment_stmt(p): + """ + comment_stmt : COMMENT + """ + di = None + p[0] = AST_Comment(di, p[1]) + + +def p_concat_list1(p): + """ + matrix_row : expr_list SEMI expr_list + """ + + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + startl, endl = p.linespan(3) + startc, endc = p.lexspan(3) + di3 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = [AST_Matrix_Row(di1, p[1]), AST_Matrix_Row(di3, p[3])] + + +def p_concat_list2(p): + """ + matrix_row : matrix_row SEMI expr_list + """ + startl, endl = p.linespan(3) + startc, endc = p.lexspan(3) + di3 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = p[1] + [AST_Matrix_Row(di3, p[3])] + + +def p_continue_stmt(p): + "continue_stmt : CONTINUE SEMI" + raise NotImplementedError("continue needs to be implemented") + + +def p_elseif_stmt(p): + """ + elseif_stmt : + | ELSE stmt_list_opt + | ELSEIF expr sep stmt_list_opt elseif_stmt + | ELSEIF LPAREN expr RPAREN stmt_list_opt elseif_stmt + """ + raise NotImplementedError("elseif needs to be implemented") + + +def p_error_stmt(p): + """ + error_stmt : ERROR_STMT SEMI + """ + raise NotImplementedError("error stmt") + + +def p_expr(p): + """expr : ident + | end + | number + | string + | colon + | NEG + | matrix + | cellarray + | expr2 + | expr1 + | lambda_expr + """ + p[0] = p[1] + + +def p_expr_2(p): + """expr : expr PLUSPLUS + | expr MINUSMINUS + """ + startl, endl = p.linespan(2) + startc, endc = p.lexspan(2) + di2 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_UnaryExpression(di2, p[1], p[2], "post") + + +def p_expr1(p): + """expr1 : MINUS expr %prec UMINUS + | PLUS expr %prec UMINUS + | NEG expr + | HANDLE ident + | PLUSPLUS ident + | MINUSMINUS ident + """ + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_UnaryExpression(di1, p[2], p[1], "pre") + + +def p_expr2(p): + """expr2 : expr AND expr + | expr ANDAND expr + | expr BACKSLASH expr + | expr COLON expr + | expr DIV expr + | expr DOT expr + | expr DOTDIV expr + | expr DOTDIVEQ expr + | expr DOTEXP expr + | expr DOTMUL expr + | expr DOTMULEQ expr + | expr EQEQ expr + | expr POW expr + | expr EXP expr + | expr EXPEQ expr + | expr GE expr + | expr GT expr + | expr LE expr + | expr LT expr + | expr MINUS expr + | expr MUL expr + | expr NE expr + | expr OR expr + | expr OROR expr + | expr PLUS expr + | expr EQ expr + | expr MULEQ expr + | expr DIVEQ expr + | expr MINUSEQ expr + | expr PLUSEQ expr + | expr OREQ expr + | expr ANDEQ expr + """ + startl, endl = p.linespan(2) + startc, endc = p.lexspan(2) + di2 = dace.types.DebugInfo(startl, startc, endl, endc) + + if p[2] == "=": + p[0] = AST_Assign(di2, p[1], p[3], p[2]) + elif p[2] == ":": + p[0] = AST_RangeExpression(di2, p[1], p[3]) + else: + p[0] = AST_BinExpression(di2, p[1], p[3], p[2]) + + +def p_expr_colon(p): + """ colon : COLON """ + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_RangeExpression(di1, None, None) + + +def p_expr_end(p): + """ end : END_EXPR """ + raise NotImplementedError("end expression needs to be implemented") + + +def p_expr_ident(p): + """ ident : IDENT """ + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_Ident(di1, p[1]) + + +def p_ident_init_opt(p): + """ + ident_init_opt : NEG + | ident + | ident EQ expr + """ 
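+    # Only the bare-identifier form is currently supported; the
+    # default-argument forms raise NotImplementedError below.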
+ if len(p) == 1: + raise NotImplementedError("default args need to be implemented") + if len(p) == 2: + p[0] = p[1] + else: + raise NotImplementedError("default args need to be implemented") + + +def p_expr_list(p): + """ + expr_list : exprs + | exprs COMMA + """ + p[0] = p[1] + + +def p_expr_number(p): + """ number : NUMBER """ + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_Constant(di1, p[1]) + + +def p_expr_stmt(p): + """ + expr_stmt : expr_list SEMI + """ + p[0] = p[1] + + +def p_expr_string(p): + """ string : STRING """ + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_Constant(di1, p[1]) + + +def p_exprs(p): + """ + exprs : expr + | exprs COMMA expr + """ + if len(p) == 2: + p[0] = [p[1]] + elif len(p) == 4: + p[0] = p[1] + p[0].append(p[3]) + + +def p_field_expr(p): + """ + expr : expr FIELD + """ + raise NotImplementedError("field expressions needs to be implemented") + + +def p_foo_stmt(p): + """ foo_stmt : expr OROR expr SEMI """ + raise NotImplementedError("foo_stmt needs to be implemented") + + +def p_for_stmt(p): + """ + for_stmt : FOR ident EQ expr SEMI stmt_list END_STMT + | FOR LPAREN ident EQ expr RPAREN SEMI stmt_list END_STMT + | FOR matrix EQ expr SEMI stmt_list END_STMT + """ + di = None + if len(p) == 8: + p[0] = AST_ForLoop(di, p[2], p[4], AST_Statements(di, p[6])) + else: + p[0] = AST_ForLoop(di, p[3], p[5], AST_Statements(di, p[8])) + + +def p_func_stmt(p): + """func_stmt : FUNCTION ident lambda_args SEMI + | FUNCTION ret EQ ident lambda_args SEMI + """ + di = None + if len(p) == 5: + p[0] = AST_Function(di, p[2], args=p[3], retvals=[]) + else: + p[0] = AST_Function(di, p[4], args=p[5], retvals=p[2]) + + +def p_funcall_expr(p): + """expr : expr LPAREN expr_list RPAREN + | expr LPAREN RPAREN + """ + startl, endl = p.linespan(1) + startc, endc = p.lexspan(1) + di1 = dace.types.DebugInfo(startl, startc, endl, endc) + + if len(p) == 4: + p[0] = AST_FunCall(di1, p[1], []) + else: + p[0] = AST_FunCall(di1, p[1], p[3]) + + +def p_global_list(p): + """global_list : ident + | global_list ident + """ + raise NotImplementedError("globals need to be implemented") + + +def p_global_stmt(p): + """ + global_stmt : GLOBAL global_list SEMI + | GLOBAL ident EQ expr SEMI + """ + raise NotImplementedError("globals need to be implemented") + + +def p_if_stmt(p): + """ + if_stmt : IF expr sep stmt_list_opt elseif_stmt END_STMT + | IF LPAREN expr RPAREN stmt_list_opt elseif_stmt END_STMT + """ + raise NotImplementedError("If/else needs to be implemented") + + +def p_lambda_args(p): + """lambda_args : LPAREN RPAREN + | LPAREN arg_list RPAREN + """ + if len(p) == 3: + p[0] = [] + else: + p[0] = p[2] + + +def p_lambda_expr(p): + """lambda_expr : HANDLE lambda_args expr + """ + raise NotImplementedError("lambda needs to be implemented") + + +def p_matrix(p): + """matrix : LBRACKET RBRACKET + | LBRACKET matrix_row RBRACKET + | LBRACKET matrix_row SEMI RBRACKET + """ + startl, endl = p.linespan(0) + startc, endc = p.lexspan(0) + di0 = dace.types.DebugInfo(startl, startc, endl, endc) + + if len(p) == 3: + p[0] = AST_Matrix(di0, []) + else: + p[0] = AST_Matrix(di0, p[2]) + + +def p_matrix_2(p): + """matrix : LBRACKET expr_list RBRACKET + | LBRACKET expr_list SEMI RBRACKET + """ + startl, endl = p.linespan(0) + startc, endc = p.lexspan(0) + di0 = dace.types.DebugInfo(startl, startc, endl, endc) + startl, endl = p.linespan(2) + 
startc, endc = p.lexspan(2) + di2 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_Matrix(di0, [AST_Matrix_Row(di2, p[2])]) + + +def p_null_stmt(p): + """ + null_stmt : SEMI + | COMMA + """ + di = None + p[0] = AST_NullStmt(di) + + +def p_parens_expr(p): + """ + expr : LPAREN expr RPAREN + """ + p[0] = p[2] + + +def p_persistent_stmt(p): + """ + persistent_stmt : PERSISTENT global_list SEMI + | PERSISTENT ident EQ expr SEMI + """ + raise NotImplementedError("persistent needs to be implemented") + + +def p_ret(p): + """ + ret : ident + | LBRACKET RBRACKET + | LBRACKET expr_list RBRACKET + """ + if len(p) == 2: + p[0] = [p[1]] + elif len(p) == 3: + p[0] = [] + else: + p[0] = p[2] + + +def p_return_stmt(p): + """ return_stmt : RETURN SEMI """ + raise NotImplementedError("return needs to be implemented") + + +def p_semi_opt(p): + """ + semi_opt : + | semi_opt SEMI + | semi_opt COMMA + """ + p[0] = AST_NullStmt(None) + + +def p_separator(p): + """ + sep : COMMA + | SEMI + """ + p[0] = p[1] + + +def p_stmt(p): + """ + stmt : continue_stmt + | comment_stmt + | func_stmt + | break_stmt + | expr_stmt + | global_stmt + | persistent_stmt + | error_stmt + | command + | for_stmt + | if_stmt + | null_stmt + | return_stmt + | switch_stmt + | try_catch + | while_stmt + | foo_stmt + | unwind + """ + # END_STMT is intentionally left out + p[0] = copy.deepcopy(p[1]) + + +def p_stmt_list(p): + """ + stmt_list : stmt + | stmt_list stmt + """ + if len(p) == 2: + if p[1] is None: + p[0] = [] + if isinstance(p[1], list): + p[0] = copy.deepcopy(p[1]) + elif len(p) == 3: + p[0] = copy.deepcopy(p[1]) + if p[2] is not None: + if isinstance(p[2], list): + p[0] = p[0] + p[2] + else: + p[0].append(p[2]) + else: + assert 0 + + +def p_stmt_list_opt(p): + """ + stmt_list_opt : + | stmt_list + """ + if len(p) == 1: + p[0] = [] + else: + p[0] = p[1] + + +def p_switch_stmt(p): + """ + switch_stmt : SWITCH expr semi_opt case_list END_STMT + """ + raise NotImplementedError("switch needs to be implemented") + + +def p_transpose_expr(p): + # p[2] contains the exact combination of plain and conjugate + # transpose operators, such as "'.''.''''". 
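+    # (See t_TRANSPOSE in lexer.py, which deliberately keeps the raw mix of
+    # quotes and dot-quotes in t.value for this rule to interpret.)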
+ """ expr : expr TRANSPOSE """ + startl, endl = p.linespan(2) + startc, endc = p.lexspan(2) + di2 = dace.types.DebugInfo(startl, startc, endl, endc) + + p[0] = AST_Transpose(di2, p[1], p[2]) + + +def p_try_catch(p): + """ + try_catch : TRY stmt_list CATCH stmt_list END_STMT + """ + raise NotImplementedError("try/catch needs to be implemented") + + +def p_unwind(p): + """ + unwind : UNWIND_PROTECT stmt_list UNWIND_PROTECT_CLEANUP stmt_list END_UNWIND_PROTECT + """ + raise NotImplementedError("unwind needs to be implemented") + + +def p_while_stmt(p): + """ + while_stmt : WHILE expr SEMI stmt_list END_STMT + """ + raise NotImplementedError("while needs to be implemented") + + +def p_error(p): + raise ValueError("Unexpected EOF") + + +parser = yacc.yacc(start="top") + + +def parse(buf, debug=False): + new_lexer = lexer.new() + p = parser.parse(buf, tracking=1, debug=debug, lexer=new_lexer) + return p + + +if __name__ == "__main__": + buf = open(sys.argv[1]).read() + p = parse(buf, debug=False) diff --git a/dace/frontend/operations.py b/dace/frontend/operations.py new file mode 100644 index 0000000000..e5d140f315 --- /dev/null +++ b/dace/frontend/operations.py @@ -0,0 +1,175 @@ +from __future__ import print_function +from functools import partial + +from timeit import default_timer as timer +import ast +import numpy as np +import sympy +import os +import sys + +from dace import types +from dace.config import Config + + +def timethis(program, title, flop_count, f, *args, **kwargs): + """ Runs a function multiple (`DACE_treps`) times, logs the running times + to a file, and prints the median time (with FLOPs if given). + @param program: The title of the measurement. + @param title: A sub-title of the measurement. + @param flop_count: Number of floating point operations in `program`. + If greater than zero, produces a median FLOPS + report. + @param f: The function to measure. + @param args: Arguments to invoke the function with. + @param kwargs: Keyword arguments to invoke the function with. + @return: Latest return value of the function. + """ + + start = timer() + REPS = int(Config.get('treps')) + times = [start] * (REPS + 1) + ret = None + for i in range(REPS): + # Call function + ret = f(*args, **kwargs) + times[i + 1] = timer() + + diffs = np.array([(times[i] - times[i - 1]) for i in range(1, REPS + 1)]) + + problem_size = sys.argv[1] if len(sys.argv) >= 2 else 0 + + if not os.path.isfile('results.log'): + with open('results.log', 'w') as f: + f.write('Program\tOptimization\tProblem_Size\tRuntime_sec\n') + + with open('results.log', 'w') as f: + for d in diffs: + f.write('%s\t%s\t%s\t%.8f\n' % (program, title, problem_size, d)) + + if flop_count > 0: + gflops_arr = (flop_count / diffs) * 1e-9 + time_secs = np.median(diffs) + GFLOPs = (flop_count / time_secs) * 1e-9 + print(title, GFLOPs, 'GFLOP/s (', time_secs * 1000, 'ms)') + else: + time_secs = np.median(diffs) + print(title, time_secs * 1000, 'ms') + + return ret + + +def detect_reduction_type(wcr_str): + """ Inspects a lambda function and tries to determine if it's one of the + built-in reductions that frameworks such as MPI can provide. + + @param wcr_str: A Python string representation of the lambda function. + @return: types.ReductionType if detected, types.ReductionType.Custom + if not detected, or None if no reduction is found. 
+ """ + if wcr_str == '' or wcr_str is None: + return None + + # Get lambda function from string + wcr = eval(wcr_str) + wcr_ast = ast.parse(wcr_str).body[0].value.body + + # Run function through symbolic math engine + a = sympy.Symbol('a') + b = sympy.Symbol('b') + try: + result = wcr(a, b) + except TypeError: # e.g., "Cannot determine truth value of relational" + result = None + + # Check resulting value + if result == sympy.Max(a, b) or (isinstance(wcr_ast, ast.Call) + and isinstance(wcr_ast.func, ast.Name) + and wcr_ast.func.id == 'max'): + return types.ReductionType.Max + elif result == sympy.Min(a, b) or (isinstance(wcr_ast, ast.Call) + and isinstance(wcr_ast.func, ast.Name) + and wcr_ast.func.id == 'min'): + return types.ReductionType.Min + elif result == a + b: + return types.ReductionType.Sum + elif result == a * b: + return types.ReductionType.Product + elif result == a & b: + return types.ReductionType.Bitwise_And + elif result == a | b: + return types.ReductionType.Bitwise_Or + elif result == a ^ b: + return types.ReductionType.Bitwise_Xor + elif isinstance(wcr_ast, ast.BoolOp) and isinstance(wcr_ast.op, ast.And): + return types.ReductionType.Logical_And + elif isinstance(wcr_ast, ast.BoolOp) and isinstance(wcr_ast.op, ast.Or): + return types.ReductionType.Logical_Or + elif (isinstance(wcr_ast, ast.Compare) + and isinstance(wcr_ast.ops[0], ast.NotEq)): + return types.ReductionType.Logical_Xor + + return types.ReductionType.Custom + + +def is_op_commutative(wcr_str): + """ Inspects a custom lambda function and tries to determine whether + it is symbolically commutative (disregarding data type). + @param wcr_str: A string in Python representing a lambda function. + @return: True if commutative, False if not, None if cannot be + determined. + """ + if wcr_str == '' or wcr_str is None: + return None + + # Get lambda function from string + wcr = eval(wcr_str) + + # Run function through symbolic math engine + a = sympy.Symbol('a') + b = sympy.Symbol('b') + try: + aRb = wcr(a, b) + bRa = wcr(b, a) + except TypeError: # e.g., "Cannot determine truth value of relational" + return None + + return aRb == bRa + + +def is_op_associative(wcr_str): + """ Inspects a custom lambda function and tries to determine whether + it is symbolically associative (disregarding data type). + @param wcr_str: A string in Python representing a lambda function. + @return: True if associative, False if not, None if cannot be + determined. + """ + if wcr_str == '' or wcr_str is None: + return None + + # Get lambda function from string + wcr = eval(wcr_str) + + # Run function through symbolic math engine + a = sympy.Symbol('a') + b = sympy.Symbol('b') + c = sympy.Symbol('c') + try: + aRbc = wcr(a, wcr(b, c)) + abRc = wcr(wcr(a, b), c) + except TypeError: # e.g., "Cannot determine truth value of relational" + return None + + return aRbc == abRc + + +def reduce(op, in_array, out_array, axis=None, identity=None): + """ Reduces an array according to an operation `op`, starting with + initial value `identity`, over the given axis (or all axes if none + given), to `out_array`. + + Requires `out_array` with one dimension less than `in_array`, or a + scalar if `axis` is None. 
+ """ + # The function is empty because it is parsed in astparser + return None diff --git a/dace/frontend/python/__init__.py b/dace/frontend/python/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/dace/frontend/python/__init__.py @@ -0,0 +1 @@ + diff --git a/dace/frontend/python/astnodes.py b/dace/frontend/python/astnodes.py new file mode 100644 index 0000000000..0b8122b353 --- /dev/null +++ b/dace/frontend/python/astnodes.py @@ -0,0 +1,196 @@ +""" Support classes for the DaCe Python AST parser. """ + +from collections import OrderedDict +from copy import deepcopy as dcpy + +from dace import data, types +from dace.frontend.python import astutils + + +class _Node(object): + """ SDFG AST node class, generated from the DaCe Python AST parser. """ + + def __init__(self, name, node_ast): + self.name = name + + # Maps: {local variable name: array subscript expression (AST node)} + self.inputs = OrderedDict() + + # Maps: {local variable name: array subscript expression (AST node)} + self.outputs = OrderedDict() + + # All variables in the parent scope + current scope + # Maps: {variable name: value AST node} + self.globals = OrderedDict() + + # All local variables defined in this scope + # Maps: {local variable name: value AST node} + self.locals = OrderedDict() + + # Maps: {transient array name: data.Data} + self.transients = OrderedDict() + + # List of parameter names + self.params = [] + + # Parent _Node object + self.parent = None + + # List of children _Node objects + self.children = [] + + # Is asynchronous + self.is_async = False + + # Node AST + self.ast = node_ast + + def __deepcopy__(self, memo): + n = object.__new__(type(self)) + + n.name = dcpy(self.name) + n.inputs = dcpy(self.inputs) + n.outputs = dcpy(self.outputs) + n.globals = self.globals + n.locals = dcpy(self.locals) + n.transients = dcpy(self.transients) + n.params = dcpy(self.params) + n.parent = None + n.children = [] + n.is_async = dcpy(self.is_async) + + return n + + # Returns the arrays local to this node's context + def arrays(self): + return OrderedDict([(k, v) for k, v in self.globals.items() + if isinstance(v, data.Data)]) + + # Returns all arrays (children included) + def all_arrays(self): + result = self.arrays() + for c in self.children: + result.update(c.all_arrays()) + return result + + def dump(self, indent=0): + print(' ' * indent + self.__class__.__name__ + ': ' + self.name) + for c in self.children: + c.dump(indent + 1) + + +class _ProgramNode(_Node): + """ SDFG AST node class. """ + pass + + +# Dataflow nodes +class _DataFlowNode(_Node): + """ Dataflow AST node superclass. """ + pass + + +class _ScopeNode(_DataFlowNode): + """ Scope (map/consume) AST node superclass. """ + pass + + +class _MapNode(_ScopeNode): + """ Map AST node type. """ + #def __init__(self, name, node_ast, range, ) + pass + + +class _ConsumeNode(_ScopeNode): + """ Consume AST node type. """ + #def __init(self, name, node_ast, stream, ...) + pass + + +class _TaskletNode(_DataFlowNode): + """ Tasklet AST node type. """ + + def __init__(self, + name, + node_ast, + language=types.Language.Python, + global_code=''): + super(_TaskletNode, self).__init__(name, node_ast) + self.language = language + self.extcode = None + self.gcode = global_code + + +class _EmptyTaskletNode(_TaskletNode): + """ Empty Tasklet AST node type. """ + pass + + +class _NestedSDFGNode(_DataFlowNode): + """ Nested SDFG AST node type. 
""" + + def __init__(self, name, node_ast, sdfg): + super(_NestedSDFGNode, self).__init__(name, node_ast) + self.sdfg = sdfg + + +# Operation nodes +class _ReduceNode(_DataFlowNode): + """ Reduce AST node type. """ + pass + + +# Control flow nodes +class _ControlFlowNode(_Node): + """ Control-flow AST node superclass. """ + pass + + +class _IterateNode(_ControlFlowNode): + """ Iteration (for-loop) AST node type. """ + pass + + +class _LoopNode(_ControlFlowNode): + """ Loop (while-loop) AST node type. """ + pass + + +class _ConditionalNode(_ControlFlowNode): + """ Conditional (if/else) AST node superclass. """ + pass + + +class _IfNode(_ConditionalNode): + """ If conditional AST node type. """ + pass + + +class _ElseNode(_ConditionalNode): + """ Else conditional AST node type. """ + pass + + +class _Memlet(object): + """ AST Memlet type. Becomes an SDFG edge. """ + + def __init__(self, data, data_name, attribute, num_accesses, + write_conflict_resolution, wcr_identity, subset, + vector_length, local_name, ast, array_dependencies): + self.data = data # type: Data + self.dataname = data_name # type: str + self.attribute = attribute # type: str + self.num_accesses = num_accesses # type: sympy math + self.wcr = write_conflict_resolution # type: ast._Lambda + self.wcr_identity = wcr_identity # type: memlet type or None + self.subset = subset # type: subsets.Subset + self.veclen = vector_length # type: int (in elements, default 1) + self.local_name = local_name # type: str + self.ast = ast # type: ast._AST + self.otherdeps = array_dependencies # type: dict(str, data.Data) + + def wcr_name(self): + label = astutils.unparse(self.wcr.body) + if self.wcr_identity is not None: + label += ', id: ' + str(self.wcr_identity) + return label diff --git a/dace/frontend/python/astparser.py b/dace/frontend/python/astparser.py new file mode 100644 index 0000000000..8ca4b8ada4 --- /dev/null +++ b/dace/frontend/python/astparser.py @@ -0,0 +1,1667 @@ +from __future__ import print_function +import ast +import astunparse +from collections import OrderedDict +import copy +from functools import wraps +import inspect + +from dace import data, subsets, symbolic, types +from dace.config import Config +from dace.frontend.python import astnodes, astutils +from dace.frontend.python.astutils import * + + +def function_to_ast(f): + """ Obtain the source code of a Python function and create an AST. + @param f: Python function. + @return: A 4-tuple of (AST, function filename, function line-number, + source code as string). + """ + try: + src = inspect.getsource(f) + # TypeError: X is not a module, class, method, function, traceback, frame, + # or code object; OR OSError: could not get source code + except (TypeError, OSError): + raise TypeError('cannot obtain source code for dace program') + + src_file = inspect.getfile(f) + _, src_line = inspect.findsource(f) + src_ast = ast.parse(_remove_outer_indentation(src)) + ast.increment_lineno(src_ast, src_line) + + return src_ast, src_file, src_line, src + + +def _remove_outer_indentation(src: str): + """ Removes extra indentation from a source Python function. + @param src: Source code (possibly indented). + @return: Code after de-indentation. + """ + lines = src.split('\n') + indentation = len(lines[0]) - len(lines[0].lstrip()) + return '\n'.join([line[indentation:] for line in lines]) + + +class FindLocals(ast.NodeVisitor): + """ Python AST node visitor that recovers all left-hand-side (stored) + locals. 
""" + + def __init__(self): + self.locals = {} + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store): + self.locals[node.id] = node + + +def parse_dace_program(f, argtypes, global_vars, modules): + """ Parses a `@dace.program` function into a _ProgramNode object. + @param f: A Python function to parse. + @param argtypes: An iterable of tuples (name, type) for the given + function's arguments. + @param global_vars: A dictionary of global variables in the closure + of `f`. + @param modules: A dictionary from an imported module name to the + module itself. + @return: Hierarchical tree of `astnodes._Node` objects, where the top + level node is an `astnodes._ProgramNode`. + @rtype: astnodes._ProgramNode + """ + src_ast, src_file, src_line, src = function_to_ast(f) + + # Find local variables + local_finder = FindLocals() + local_finder.visit(src_ast) + local_vars = local_finder.locals + + # 1. Inline all "dace.call"ed functions + inliner = FunctionInliner(global_vars, modules, local_vars) + inliner.visit(src_ast) + + # 2. resolve all the symbols in the AST + allowed_globals = global_vars.copy() + allowed_globals.update(argtypes) + symresolver = SymbolResolver(allowed_globals) + symresolver.visit(src_ast) + + # 3. Parse the DaCe program to a hierarchical dependency representation + ast_parser = ParseDaCe(src_file, src_line, argtypes, global_vars, modules, + symresolver) + ast_parser.visit(src_ast) + pdp = ast_parser.program + pdp.source = src + pdp.filename = src_file + pdp.param_syms = sorted(symbolic.getsymbols(argtypes.values()).items()) + pdp.argtypes = argtypes + + return pdp + + +class MemletRemover(ExtNodeTransformer): + """ A Python AST transformer that removes memlet expressions of the type + `a << b[c]` and `d >> e(f)[g]`. """ + + def visit_TopLevelExpr(self, node): + # This is a DaCe shift, omit it + if isinstance(node.value, ast.BinOp): + if isinstance(node.value.op, ast.LShift) or isinstance( + node.value.op, ast.RShift): + return None + return self.generic_visit(node) + + +class ModuleInliner(ExtNodeTransformer): + """ A Python AST transformer that renames modules from their imported alias + to their actual name. """ + + def __init__(self, modules): + self.modules = modules + + def visit_Attribute(self, node): + attrname = rname(node) + module_name = attrname[:attrname.rfind('.')] + if module_name in self.modules: # math or equivalent modules + modname = self.modules[module_name] + node.value = ast.copy_location( + ast.Name(id=(modname), ctx=ast.Load), node.value) + return node + return self.generic_visit(node) + + +# Parses a DaCe program +class ParseDaCe(ExtNodeVisitor): + """ A Python AST visitor that creates DaCe program trees. 
+ @see: parse_dace_program + """ + + def __init__(self, filename, lineoffset, argtypes, global_vars, modules, + symresolver): + self.curnode = None + self.program_name = None + self.filename = filename + self.lineoffset = lineoffset + self.argtypes = argtypes + self.modules = modules + self.globals = global_vars + self.symresolver = symresolver + + # Maps: {array name: data.Data)} + self.global_arrays = OrderedDict() + self.global_arrays.update(argtypes) + + # Entry point to the program + self.program = None + + ############################################################### + # Helper functions + ############################################################### + def _get_module(self, node): + try: + fullmodname = inspect.getmodule(eval(unparse(node), + self.globals)).__name__ + except NameError: + fullmodname = '' + # Only use the top-level module + if fullmodname.find('.') >= 0: + return fullmodname[:fullmodname.find('.')] + return fullmodname + + def _inner_eval_ast(self, node, additional_syms=None): + code = unparse(node) + syms = {} + syms.update(self.curnode.globals) + if additional_syms is not None: + syms.update(additional_syms) + + # First try to evaluate normally + try: + return eval(code, syms) + except: # Literally anything can happen here + # If doesn't work, try to evaluate as a sympy expression + # Replace subscript expressions with function calls (sympy support) + code = code.replace('[', '(') + code = code.replace(']', ')') + return symbolic.pystr_to_symbolic(code) + + def _compile_ast(self, node_body, line_offset, filename): + self.symresolver.visit(node_body) + wrapper = ast.Module(body=[node_body]) + + if line_offset is not None: + for node in ast.walk(wrapper): + node.lineno = line_offset + node.col_offset = 0 + + codeobj = compile(wrapper, filename, 'exec') + gen_module = {} + gen_module.update(self.globals) + exec(codeobj, gen_module) + return gen_module[node_body.name] + + def _eval_ast(self, node): + if node is None: + return None + elif isinstance(node, ast.Call): + # Only work on allowed functions and external functions according to + # decision flowchart for intra-program function evaluation: + # 1. Does it exist in the same program + already parsed? + # 2. Is it a @dace.external_function? + # 3. Is it one of the standard functions from the allowed module? + # 4. 
If neither of the previous, fail + func = rname(node) + + # Function call to a tasklet defined within the same program + if func in self.curnode.globals and isinstance( + self.curnode.globals[func], ast.FunctionDef): + # Since the function is never compiled by Python, we need to + # do so ourselves + compiled_func = self._compile_ast( + self.curnode.globals[func], self.lineoffset, self.filename) + return self._inner_eval_ast(node, {func: compiled_func}) + + # Standard function call, e.g., int(), math.sin() + elif self._get_module(node.func) in self.modules: + return self._inner_eval_ast(node) + + # External function calls + elif func in self.globals: + if isinstance(self.globals[func], types._external_function): + # External function needs to be recompiled with current + # symbols + src_ast, src_file, src_line, src = function_to_ast( + self.globals[func].func) + compiled_func = self._compile_ast(src_ast.body[0], + src_line, src_file) + return self._inner_eval_ast(node, {func: compiled_func}) + else: + return self._inner_eval_ast(node) + + else: + return self._inner_eval_ast(node) + elif isinstance(node, ast.FunctionDef): + compiled_sdfg = self._compile_ast(node, node.lineno, self.filename) + return compiled_sdfg.to_sdfg() + else: + # Not a function, try to evaluate + return self._inner_eval_ast(node) + + # Track local variables + def _set_locals(self): + if self.curnode.parent is None: + # Handle parameters (first set all to symbols, then set type + # descriptors for arrays) + self.curnode.globals.update( + {k: symbolic.symbol(k) + for k in self.curnode.params}) + self.curnode.globals.update(self.globals) + self.curnode.globals.update(self.global_arrays) + else: + self.curnode.globals.update(self.curnode.parent.globals) + self.curnode.globals.update( + {k: symbolic.symbol(k) + for k in self.curnode.params}) + + # Helper function to find the dtype of an array, either as a keyword or + # as the last parameter + def getarg_or_kwarg(self, node, argoff, argname): + if len(node.args) > argoff: + return node.args[argoff] + for k in node.keywords: + if rname(k) == argname: + return k.value + return None + + ############################################################### + # Parsing functions + ############################################################### + + def _ndslice_to_subset(self, ndslice): + is_tuple = [isinstance(x, tuple) for x in ndslice] + if not any(is_tuple): + return subsets.Indices(ndslice) + else: + if not all(is_tuple): + # If a mix of ranges and indices is found, convert to range + for i in range(len(ndslice)): + if not is_tuple[i]: + ndslice[i] = (ndslice[i], ndslice[i], 1) + return subsets.Range(ndslice) + + def _fill_missing_slices(self, ast_ndslice, array, indices): + # Filling ndslice with default values from array dimensions + # if ranges not specified (e.g., of the form "A[:]") + ndslice = [None] * len(ast_ndslice) + ndslice_size = 1 + offsets = [] + idx = 0 + for i, dim in enumerate(ast_ndslice): + if isinstance(dim, tuple): + rb = self._eval_ast(dim[0]) + re = self._eval_ast(dim[1]) + if re is not None: + re -= 1 + rs = self._eval_ast(dim[2]) + if rb is None: rb = 0 + if re is None: re = array.shape[indices[idx]] - 1 + if rs is None: rs = 1 + ndslice[i] = (rb, re, rs) + offsets.append(i) + idx += 1 + else: + ndslice[i] = self._eval_ast(dim) + + return ndslice, offsets + + # Parses a memlet statement + def ParseMemlet(self, local_name, rhsnode): + rhs = rname(rhsnode) + if rhs.find('.') >= 0: # attribute, form G.out_edges[:] + arrname = rhs[:rhs.find('.')] + arrattr = 
rhs[rhs.find('.') + 1:] + else: # normal memlet, form A(1)[i,j] + arrname = rhs + arrattr = None + + array = self.curnode.globals[arrname] + + # Determine number of accesses to the memlet (default is the slice size) + num_accesses = None + write_conflict_resolution = None + wcr_identity = None + # Detects expressions of the form "A(2)[...]", "A(300)", "A(1, sum)[:]" + if isinstance(rhsnode, ast.Call): + if len(rhsnode.args) < 1 or len(rhsnode.args) > 3: + raise DaCeSyntaxError( + self, rhsnode, + 'Number of accesses in memlet must be a number, symbolic ' + 'expression, or -1') + num_accesses = self._eval_ast(rhsnode.args[0]) + if len(rhsnode.args) >= 2: + write_conflict_resolution = rhsnode.args[1] + if len(rhsnode.args) >= 3: + wcr_identity = ast.literal_eval(rhsnode.args[2]) + elif isinstance(rhsnode, ast.Subscript) and isinstance( + rhsnode.value, ast.Call): + if len(rhsnode.value.args) < 1 or len(rhsnode.value.args) > 3: + raise DaCeSyntaxError( + self, rhsnode, + 'Number of accesses in memlet must be a number, symbolic ' + 'expression, or -1') + num_accesses = self._eval_ast(rhsnode.value.args[0]) + if len(rhsnode.value.args) >= 2: + write_conflict_resolution = rhsnode.value.args[1] + if len(rhsnode.value.args) >= 3: + wcr_identity = ast.literal_eval(rhsnode.value.args[2]) + + array_dependencies = {} + + # Get memlet range + ndslice = [(0, s - 1, 1) for s in array.shape] + if isinstance(rhsnode, ast.Subscript): + # Parse and evaluate ND slice(s) (possibly nested) + ast_ndslices = subscript_to_ast_slice_recursive(rhsnode) + offsets = list(range(len(array.shape))) + + # Loop over nd-slices (A[i][j][k]...) + subset_array = [] + for ast_ndslice in ast_ndslices: + # Loop over the N dimensions + ndslice, offsets = self._fill_missing_slices( + ast_ndslice, array, offsets) + subset_array.append(self._ndslice_to_subset(ndslice)) + + subset = subset_array[0] + + # Compose nested indices, e.g., of the form "A[i,:,j,:][k,l]" + for i in range(1, len(subset_array)): + subset = subset.compose(subset_array[i]) + + # Compute additional array dependencies (as a result of + # indirection) + for dim in subset: + if not isinstance(dim, tuple): dim = [dim] + for r in dim: + for expr in symbolic.swalk(r): + if symbolic.is_sympy_userfunction(expr): + arr = expr.func.__name__ + array_dependencies[arr] = self.curnode.globals[arr] + + else: # Use entire range + subset = self._ndslice_to_subset(ndslice) + + # If undefined, default number of accesses is the slice size + if num_accesses is None: + num_accesses = subset.num_elements() + + # This is a valid DaCe load/store, register it + return astnodes._Memlet( + array, arrname, arrattr, num_accesses, write_conflict_resolution, + wcr_identity, subset, 1, local_name, rhsnode, array_dependencies) + + # Helper function: parses DaCe array statement + def ParseArrayStatement(self, node, bInput): + if self.curnode is None: + raise DaCeSyntaxError( + self, node, + 'DaCe load/store statement declared outside function bounds') + + lhs = rname(node.value.left) + rhs = rname(node.value.right) + + if rhs.find('.') >= 0: # attribute, form G.out_edges[:] + arrname = rhs[:rhs.find('.')] + arrattr = rhs[rhs.find('.') + 1:] + else: # normal memlet, form A(1)[i,j] + arrname = rhs + arrattr = None + + arrays = self.curnode.arrays() + + # If this is not an undefined symbol (and the rhs is not a DaCe array), + # this is just a regular shift + if lhs in self.curnode.locals: + if arrname not in arrays: + return + else: + raise DaCeSyntaxError( + self, node, + 'Cannot load/store DaCe 
variable using an existing symbol') + + if arrname not in arrays: + raise DaCeSyntaxError( + self, node, 'Cannot load/store DaCe variable "' + arrname + + '" from a non-DaCe array') + + lhs_name = lhs + if lhs in arrays: + lhs = arrays[lhs] + + # Make sure the DaCe assignment is unique + if lhs in self.curnode.inputs: + raise DaCeSyntaxError( + self, node, 'Variable already assigned to another input') + if lhs in self.curnode.outputs: + raise DaCeSyntaxError( + self, node, 'Variable already assigned to another output') + + ######################## + # Determine the properties of the memlet + memlet = self.ParseMemlet(lhs_name, node.value.right) + + if bInput: + self.curnode.inputs[lhs_name] = memlet + else: + self.curnode.outputs[lhs_name] = memlet + + def ParseCallAssignment(self, node, target): + funcname = rname(node.func) + modname = self._get_module(node.func) + + ###################################### + # Handle DaCe-specific calls + if modname == 'dace': # modname is already the real name of the module + # Do not allow instantiation of ND arrays and DaCe scalars + if funcname == "ndarray" or funcname == "scalar": + raise DaCeSyntaxError( + self, node, + 'Cannot define a DaCe array within a program, try using ' + 'dace.define_local or dace.define_local_scalar') + + # Handle transient variables + if funcname.endswith(".define_local"): + if len(node.args) < 1: + raise DaCeSyntaxError( + self, node, + 'Invalid call to define_local, at least 1 parameter ' + 'is required') + if self.getarg_or_kwarg(node, 1, 'dtype') is None: + raise DaCeSyntaxError( + self, node, + 'Transient variable declaration must specify type') + + # Construct type descriptor + shape = self._eval_ast(node.args[0]) + dtype = self._eval_ast(self.getarg_or_kwarg(node, 1, 'dtype')) + allow_conflicts = self._eval_ast( + self.getarg_or_kwarg(node, 2, 'allow_conflicts')) + allow_conflicts = False if allow_conflicts is None else True + try: + tdesc = data.Array( + dtype, + shape, + transient=True, + allow_conflicts=allow_conflicts) + except TypeError as ex: + raise DaCeSyntaxError(self, node, str(ex)) + + self.curnode.transients[rname(target)] = tdesc + self.curnode.globals[rname(target)] = tdesc + return None + + elif funcname.endswith(".define_local_scalar"): + if self.getarg_or_kwarg(node, 0, 'dtype') is None: + raise DaCeSyntaxError( + self, node, + 'Transient variable declaration must specify type') + + # Construct type descriptor + dtype = self._eval_ast(self.getarg_or_kwarg(node, 0, 'dtype')) + allow_conflicts = self._eval_ast( + self.getarg_or_kwarg(node, 1, 'allow_conflicts')) + allow_conflicts = False if allow_conflicts is None else True + + tdesc = data.Scalar( + dtype, transient=True, allow_conflicts=allow_conflicts) + + self.curnode.transients[rname(target)] = tdesc + self.curnode.globals[rname(target)] = tdesc + return None + elif funcname.endswith(".define_stream") or funcname.endswith( + ".define_streamarray"): + argOffset = 0 + if funcname.endswith('array'): + # Defined stream array, expecting shape + shape = self._eval_ast( + self.getarg_or_kwarg(node, 0, 'shape')) + argOffset += 1 + else: + shape = [1] + + dtype = self._eval_ast( + self.getarg_or_kwarg(node, argOffset, 'dtype')) + + # Optional parameters + internal_size = self._eval_ast( + self.getarg_or_kwarg(node, argOffset + 1, 'buffer_size')) + + tdesc = data.Stream( + dtype, 1, internal_size, shape=shape, transient=True) + + self.curnode.transients[rname(target)] = tdesc + self.curnode.globals[rname(target)] = tdesc + return None + elif 
(funcname.rfind('.') != -1 + and funcname[funcname.rfind('.') + + 1:] in types.TYPECLASS_STRINGS): + return node + else: + raise DaCeSyntaxError( + self, node, 'Unrecognized function call \'%s\'' % funcname) + ###################################### + # Other calls are treated as memlet functions (independent of arrays, + # inline-able) + else: + return node + + def _add_possible_inputs(self, nodes, prim): + if not isinstance(nodes, list): + nodes = [nodes] + extended_nodes = [] + final_nodes = [] + + # Extract values from lists, tuples and subsets + for node in nodes: + if isinstance(node, tuple): + final_nodes.extend(list(node)) + elif isinstance(node, subsets.Range): + for dim in node.ranges: + final_nodes.extend(list(dim)) + elif isinstance(node, subsets.Indices): + final_nodes.extend(list(node)) + + # Find AST names + for node in extended_nodes: + if isinstance(node, ast.AST): + for subnode in ast.walk(node): + if isinstance(subnode, ast.Name): + final_nodes.append(subnode.id) + else: + final_nodes.append(node) + nodeset = set() + for n in final_nodes: + if symbolic.issymbolic(n): + nodeset.update(str(s) for s in n.free_symbols) + elif isinstance(n, str): + nodeset.add(n) + + arrs = self.curnode.arrays() + for input in nodeset: + if input in arrs: + inName = '__DACEIN_' + input + prim.inputs[inName] = astnodes._Memlet( + arrs[input], input, None, 1, None, None, + subsets.Indices([0]), 1, None, None, {}) + + ############################################################### + # AST visiting functions + ############################################################### + + def visit_FunctionDef(self, node, is_async=False): + # Obtain function name + parent_node = self.curnode + curprim = None + + arrays = OrderedDict() + if self.curnode is not None: + arrays = self.curnode.arrays() + + # Obtain program/primitive name (only one program is allowed) + if (len(node.decorator_list) > 0): + if (len(node.decorator_list) > 1): + raise DaCeSyntaxError(self, node, + 'Only one DaCe decorator is allowed') + + # Make sure that the module is DaCe + dec = node.decorator_list[0] + decname = rname(dec) + if isinstance(dec, ast.Call): + modname = self._get_module(dec.func) + else: + modname = self._get_module(dec) + if modname not in self.modules.values() or modname != 'dace': + raise DaCeSyntaxError( + self, node, + 'Decorators from module \'%s\' not allowed' % modname) + ##################################### + + # Create DaCe program node + if decname.endswith('.program'): + if self.program is not None: + # Parse internal program separately as an SDFG of its own + sdfg = self._eval_ast(node) + curprim = astnodes._NestedSDFGNode(node.name, node, sdfg) + + # Inherit I/O from immediate parent + curprim.inputs = copy.copy(parent_node.inputs) + curprim.outputs = copy.copy(parent_node.outputs) + # Cancel parent node's relevant I/O + parent_node.inputs.clear() + parent_node.outputs.clear() + + # Set children of parent primitive, if it is a primitive + if parent_node is not None and curprim is not None: + parent_node.children.append(curprim) + curprim.parent = parent_node + + # Exit so that child AST nodes will not be parsed + return + + self.program = astnodes._ProgramNode(node.name, node) + curprim = self.program + + # Parse primitives + # Dataflow primitives + elif decname.endswith('map'): + curprim = astnodes._MapNode(node.name, node) + + # If the arguments are defined in the decorator + if 'args' in dir(dec) and len(dec.args) > 0: + curprim.range = subsets.Range( + subscript_to_slice(dec.args[0], arrays)[1]) + 
else: + try: + curprim.range = subsets.Range([ + subscript_to_slice(arg.annotation, arrays)[1][0] + for arg in node.args.args + ]) + except (AttributeError, TypeError, ValueError): + raise DaCeSyntaxError( + self, node, + 'All arguments in DaCe primitive %s must be annotated with a range' + % node.name) + self._add_possible_inputs(curprim.range, curprim) + + elif decname.endswith('consume'): + curprim = astnodes._ConsumeNode(node.name, node) + + # If the arguments are defined in the decorator + if 'args' in dir(dec) and len(dec.args) > 0: + if dec.args[0].id not in self.curnode.globals: + raise DaCeSyntaxError( + self, node, 'Undefined stream %s in consume %s' % + (dec.args[0].id, node.name)) + curprim.stream = self.curnode.globals[rname(dec.args[0])] + ast_memlet = self.ParseMemlet(node.args.args[0].arg, + dec.args[0]) + ast_memlet.num_accesses = -1 + curprim.inputs[node.args.args[0].arg] = ast_memlet + if len(dec.args) < 2: + raise DaCeSyntaxError( + self, node, + 'Consume %s missing required argument: ' + 'number of processing elements' % node.name) + curprim.num_pes = symbolic.pystr_to_symbolic( + unparse(dec.args[1])) + if len(dec.args) > 2: + curprim.condition = unparse(dec.args[2]) + else: + curprim.condition = None + else: + raise DaCeSyntaxError( + self, node, + 'Consume syntax only supports parameters at the ' + 'decorator') + self._add_possible_inputs(curprim.stream, curprim) + + elif decname.endswith('tasklet'): + # Parse arguments + lang = None + gcode = None + if isinstance(dec, ast.Call): + lang = self._eval_ast( + self.getarg_or_kwarg(dec, 0, 'language')) + gcode = self._eval_ast( + self.getarg_or_kwarg(dec, 1, 'global_code')) + + if lang is None: + lang = types.Language.Python + else: + try: + lang = types.Language[lang] + except KeyError: + raise DaCeSyntaxError( + self, node, + 'Unrecognized tasklet language "%s"' % lang) + if gcode is None: + gcode = '' + + curprim = astnodes._TaskletNode(node.name, node, lang, gcode) + + # Control flow primitives + elif decname.endswith('iterate'): + if isinstance(parent_node, astnodes._DataFlowNode): + raise DaCeSyntaxError( + self, node, 'Control flow within data flow disallowed') + + curprim = astnodes._IterateNode(node.name, node) + + if 'args' in dir(dec) and len( + dec.args + ) > 0: # If the arguments are defined in the decorator + curprim.range = subsets.Range( + subscript_to_slice(dec.args[0], arrays)[1]) + else: + try: + curprim.range = subsets.Range([ + subscript_to_slice(arg.annotation, arrays)[1][0] + for arg in node.args.args + ]) + except (AttributeError, TypeError, ValueError): + raise SyntaxError( + 'All arguments in DaCe primitive %s must be annotated with a range' + % node.name) + self._add_possible_inputs(curprim.range, curprim) + + elif decname.endswith('loop'): + if isinstance(parent_node, astnodes._DataFlowNode): + raise DaCeSyntaxError( + self, node, 'Control flow within data flow disallowed') + + curprim = astnodes._LoopNode(node.name, node) + + if 'args' in dir(dec) and len( + dec.args + ) > 0: # If the arguments are defined in the decorator + curprim.condition = dec.args[0] + else: + raise SyntaxError( + 'Condition must be given as argument to decorator in DaCe primitive %s' + % node.name) + self._add_possible_inputs(curprim.condition, curprim) + + elif decname.endswith('conditional'): + if isinstance(parent_node, astnodes._DataFlowNode): + raise DaCeSyntaxError( + self, node, 'Control flow within data flow disallowed') + + curprim = astnodes._ConditionalNode(node.name, node) + + if 'args' in dir(dec) and len( 
+ dec.args + ) > 0: # If the arguments are defined in the decorator + curprim.condition = dec.args[0] + else: + raise SyntaxError( + 'Condition must be given as argument to decorator in DaCe primitive %s' + % node.name) + self._add_possible_inputs(curprim.condition, curprim) + + else: + raise DaCeSyntaxError(self, node, + 'Unrecognized primitive ' + decname) + + if '.async_' in decname or is_async: + curprim.is_async = True + # End of program/primitive name + + # If this is a primitive + if curprim is not None: + # If function definition contains arguments + if 'args' in dir(node): + for arg in node.args.args: + + # If it is not the program, add locals as symbols + if self.program != node.name: + curprim.globals[rname(arg)] = symbolic.symbol( + rname(arg)) + if curprim is not None: + curprim.params.append(rname(arg)) + + # Set children of parent primitive, if it is a primitive + if parent_node is not None and curprim is not None: + parent_node.children.append(curprim) + curprim.parent = parent_node + + self.curnode = curprim + + # Track local variables + self._set_locals() + + # Mandatory (to keep visiting children) + for stmt in node.body: + self.visit_TopLevel(stmt) + + # After traversing the function, pop "function name stack" + self.curnode = parent_node + else: # Not a primitive + self.curnode.locals[node.name] = node + self.curnode.globals[node.name] = node + + # Mandatory (to keep visiting children) + for stmt in node.body: + self.visit_TopLevel(stmt) + + def visit_AsyncFunctionDef(self, node): + # Treat as a plain function + self.visit_FunctionDef(node, is_async=True) + + def visit_Call(self, node): + if (not isinstance(node.func, ast.Attribute) + or node.func.value.id not in self.modules + or self.modules[node.func.value.id] != 'dace'): + self.generic_visit(node) + return + + # Reduce call + if node.func.attr.endswith('reduce'): + dec = node + # Mandatory arguments + wcr = dec.args[0] + src = dec.args[1] + dst = dec.args[2] + # In case the axis argument is given without explicit kwarg + # notation + axisarg = dec.args[3] if len(dec.args) > 3 else None + identityarg = dec.args[4] if len(dec.args) > 4 else None + + curprim = astnodes._ReduceNode('reduce', wcr) + curprim.axes = get_tuple(self, getkwarg(dec, 'axis', axisarg)) + curprim.identity = get_tuple( + self, getkwarg(dec, 'identity', identityarg)) + if curprim.identity is not None: + curprim.identity = curprim.identity[0] + curprim.inputs['input'] = self.ParseMemlet('input', src) + curprim.outputs['output'] = self.ParseMemlet('output', dst) + + # Set children of parent primitive, if it is a primitive + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + def visit_TopLevelExpr(self, node): + if isinstance(node.value, ast.BinOp): + if (isinstance(node.value.op, ast.LShift)): + self.ParseArrayStatement(node, True) + return + if (isinstance(node.value.op, ast.RShift)): + self.ParseArrayStatement(node, False) + return + elif isinstance(node.value, ast.Str): + self.visit_TopLevelStr(node.value) + return + + self.generic_visit(node) + + def visit_TopLevelStr(self, node): + if isinstance(self.curnode, astnodes._TaskletNode): + if self.curnode.extcode != None: + raise DaCeSyntaxError( + self, node, + 'Cannot provide more than one intrinsic implementation ' + + 'for tasklet') + self.curnode.extcode = node.s + return + + self.generic_visit(node) + + # Detect locals and transient variables + def visit_Assign(self, node): + # Don't allow assignment to tuples (for now) + if len(node.targets) > 1: + raise 
DaCeSyntaxError(self, node, + 'Assignment to tuples not supported (yet)') + target = node.targets[0] + if isinstance(target, ast.Tuple): + if len(target.elts) > 1: + raise DaCeSyntaxError( + self, node, 'Assignment to tuples not supported (yet)') + target = target.elts[0] + + # Tasklet code + if self.curnode is not None: + if isinstance(node.value, ast.Call) and\ + not isinstance(self.curnode, astnodes._TaskletNode): + retval = self.ParseCallAssignment(node.value, target) + if retval is not None: + self.curnode.locals[rname(target)] = retval + self.curnode.globals[rname(target)] = retval + + # No need to further visit the node's children + return + else: + if isinstance(self.curnode, astnodes._DataFlowNode): + self.curnode.locals[rname(target)] = None + self.curnode.globals[rname(target)] = None + else: + retval = self._eval_ast(node.value) + self.curnode.locals[rname(target)] = retval + self.curnode.globals[rname(target)] = retval + + # No need to further visit the node's children + return + + self.generic_visit(node) + + # Visit statements that define locals + def visit_Name(self, node): + if self.curnode is None: + arrays = self.global_arrays + else: + arrays = self.curnode.arrays() + + if node.id in arrays and (not isinstance(arrays[node.id], data.Scalar) + or arrays[node.id].transient): + if isinstance(node.ctx, ast.Load) or isinstance( + node.ctx, ast.Store): + raise DaCeSyntaxError( + self, node, + 'Directly reading from and writing to arrays is not ' + 'allowed. Please use memlet notation (a << A[i])') + + self.generic_visit(node) + + # Control flow blocks + ######################### + def visit_For(self, node): + # Syntax: Only accept for loops without 'else'; only accept for loops + # with structure 'for in range()' + if len(node.orelse) > 0: + raise DaCeSyntaxError( + self, node, + 'Loops with \'else\' footer are not allowed in DaCe programs') + + if self.curnode is not None: + # Verify syntax + ######################################################## + # We allow only three types of for loops: + # 1. `for i in range(...)`: Creates a looping state + # 2. `for i in parrange(...)`: Creates a 1D map + # 3. 
`for i,j,k in dace.map[0:M, 0:N, 0:K]`: Creates an ND map + + if isinstance(node.iter, ast.Call): + funcname = rname(node.iter.func) + modname = self._get_module(node.iter.func) + elif isinstance(node.iter, ast.Subscript): + funcname = rname(node.iter.value) + modname = self._get_module(node.iter.value) + else: + funcname, modname = None, None + + # Case 1: Iterate + if (isinstance(node.target, ast.Name) + and isinstance(node.iter, ast.Call) + and isinstance(node.iter.func, ast.Name) + and node.iter.func.id == 'range'): + # If we are inside a dataflow construct, ignore + if isinstance(self.curnode, astnodes._DataFlowNode): + self.generic_visit(node) + return + + # Obtain parameters + varname = node.target.id + nargs = len(node.iter.args) + var_rb = 0 if nargs < 2 else symbolic.pystr_to_symbolic( + unparse(node.iter.args[0])) + var_re = (symbolic.pystr_to_symbolic( + unparse(node.iter.args[1])) + if nargs > 1 else symbolic.pystr_to_symbolic( + unparse(node.iter.args[0]))) - 1 + var_rs = 1 if nargs < 3 else symbolic.pystr_to_symbolic( + unparse(node.iter.args[2])) + + # Create node + curprim = astnodes._IterateNode('iterate_' + str(node.lineno), + node) + curprim.range = [(var_rb, var_re, var_rs)] + curprim.params = [varname] + self._add_possible_inputs(curprim.range, curprim) + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + # Traverse into loop + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.body: + self.visit(stmt) + self.curnode = oldnode + #################### + return + + # Case 2: 1D map (for i in parrange(...)) + elif (isinstance(node.target, ast.Name) + and isinstance(node.iter, ast.Call) + and isinstance(node.iter.func, ast.Name) + and node.iter.func.id == 'parrange'): + curprim = astnodes._MapNode('map_' + str(node.lineno), node) + + # Get arguments for range + maprange = [] + if len(node.iter.args) == 1: # end only + maprange = [(None, node.iter.args[0], None)] + elif len(node.iter.args) == 2: # begin, end + maprange = [(node.iter.args[0], node.iter.args[1], None)] + elif len(node.iter.args) == 3: # begin, end, skip + maprange = [(node.iter.args[0], node.iter.args[1], + node.iter.args[2])] + else: + raise DaCeSyntaxError( + self, node, + 'Invalid number of arguments for "parrange"') + + curprim.range = subsets.Range( + astrange_to_symrange(maprange, self.curnode.arrays())) + curprim.params = [rname(node.target)] + + self._add_possible_inputs(curprim.range, curprim) + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + # Traverse into loop + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.body: + self.visit(stmt) + self.curnode = oldnode + #################### + + return + + # Case 3: ND map + elif (isinstance(node.target, ast.Tuple) + and isinstance(node.iter, ast.Subscript) + and isinstance(node.iter.value, ast.Attribute) + and modname == 'dace' and node.iter.value.attr == 'map'): + curprim = astnodes._MapNode('map_' + str(node.lineno), node) + + # Get range from array subscript, check for length mismatch + _, range_values = subscript_to_slice(node.iter, + self.curnode.arrays()) + range_keys = [rname(n) for n in node.target.elts] + if len(range_keys) != len(range_values): + raise DaCeSyntaxError( + self, node, + 'Map range must match tuple length in for loop') + curprim.params = range_keys + curprim.range = subsets.Range(range_values) + + self._add_possible_inputs(curprim.range, curprim) + self.curnode.children.append(curprim) + curprim.parent = 
self.curnode + + # Traverse into loop + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.body: + self.visit(stmt) + self.curnode = oldnode + #################### + + return + + # No match + else: + raise DaCeSyntaxError( + self, node, 'Invalid loop syntax. Supported options are:\n' + ' for in range()\n' + ' for in parrange()\n' + ' for in dace.map[ranges]') + ####################################################### + + self.generic_visit(node) + + def visit_While(self, node): + # Syntax: Only accept while loops without 'else' + if len(node.orelse) > 0: + raise DaCeSyntaxError( + self, node, + 'Loops with \'else\' footer are not allowed in DaCe programs') + + if self.curnode is not None: + # If we are inside a dataflow construct, ignore + if not isinstance(self.curnode, astnodes._DataFlowNode): + # Obtain parameters + cond = node.test + + # Create node + curprim = astnodes._LoopNode('while_' + str(node.lineno), node) + curprim.condition = cond + self._add_possible_inputs(curprim.condition, curprim) + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + # Traverse into loop + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.body: + self.visit(stmt) + self.curnode = oldnode + #################### + return + + self.generic_visit(node) + + def visit_If(self, node): + if self.curnode is not None: + # If we are inside a dataflow construct, ignore + if not isinstance(self.curnode, astnodes._DataFlowNode): + # Obtain parameters + cond = node.test + + # Create node + curprim = astnodes._IfNode('if_' + str(node.lineno), node) + curprim.condition = cond + self._add_possible_inputs(curprim.condition, curprim) + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + # Traverse into condition + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.body: + self.visit(stmt) + self.curnode = oldnode + + # Process 'else'/'elif' statements + if len(node.orelse) > 0: + # Create node + curprim = astnodes._ElseNode( + 'else_' + str(node.orelse[0].lineno), node) + # Negate condition + curprim.condition = astutils.negate_expr(cond) + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + # Traverse into condition + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.orelse: + self.visit(stmt) + self.curnode = oldnode + + return + + self.generic_visit(node) + + def visit_With(self, node, is_async=False): + # "with dace.tasklet" syntax + if len(node.items) == 1: + dec = node.items[0].context_expr + if isinstance(dec, ast.Call): + funcname = rname(dec.func) + modname = self._get_module(dec.func) + elif isinstance(dec, ast.Attribute): + funcname = rname(dec) + modname = self._get_module(dec) + else: + funcname, modname = None, None + + if modname == 'dace' and funcname.endswith('.tasklet'): + # Parse as tasklet + # NOTE: This is almost a direct copy of the tasklet parser + # above. 
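+                # For reference, the syntax handled here looks like the
+                # following (a sketch; array and variable names are
+                # illustrative, not from the original sources):
+                #
+                #     with dace.tasklet:
+                #         a << A[i]
+                #         b = a * 2
+                #         b >> B[i]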
+ lang = None + gcode = None + if isinstance(dec, ast.Call): + lang = self._eval_ast( + self.getarg_or_kwarg(dec, 0, 'language')) + gcode = self._eval_ast( + self.getarg_or_kwarg(dec, 1, 'global_code')) + + if lang is None: + lang = types.Language.Python + else: + try: + lang = types.Language[lang] + except KeyError: + raise DaCeSyntaxError( + self, node, + 'Unrecognized tasklet language "%s"' % lang) + if gcode is None: + gcode = '' + + curprim = astnodes._TaskletNode('tasklet_' + str(node.lineno), + node, lang, gcode) + if self.curnode is not None: + self.curnode.children.append(curprim) + curprim.parent = self.curnode + + # Traverse into tasklet + oldnode = self.curnode + self.curnode = curprim + self._set_locals() + for stmt in node.body: + self.visit_TopLevel(stmt) + self.curnode = oldnode + return + + raise DaCeSyntaxError( + self, node, + 'General "with" statements disallowed in DaCe programs') + + ######################### + + ## Disallowed statements + def visit_Global(self, node): + raise DaCeSyntaxError( + self, node, '"global" statements disallowed in DaCe sub-programs') + + def visit_Delete(self, node): + raise DaCeSyntaxError(self, node, + '"del" statements disallowed in DaCe programs') + + def visit_Import(self, node): + raise DaCeSyntaxError(self, node, + 'imports disallowed in DaCe programs') + + def visit_ImportFrom(self, node): + raise DaCeSyntaxError(self, node, + 'imports disallowed in DaCe programs') + + def visit_Assert(self, node): + raise DaCeSyntaxError( + self, node, '"assert" statements disallowed in DaCe programs') + + def visit_Pass(self, node): + raise DaCeSyntaxError(self, node, + '"pass" statements disallowed in DaCe programs') + + def visit_Exec(self, node): + raise DaCeSyntaxError(self, node, + '"exec" statements disallowed in DaCe programs') + + def visit_Print(self, node): + raise DaCeSyntaxError( + self, node, '"print" statements disallowed in DaCe programs') + + def visit_Nonlocal(self, node): + raise DaCeSyntaxError( + self, node, '"nonlocal" statements disallowed in DaCe programs') + + def visit_Yield(self, node): + raise DaCeSyntaxError( + self, node, '"yield" statements disallowed in DaCe programs') + + def visit_YieldFrom(self, node): + raise DaCeSyntaxError( + self, node, '"yield" statements disallowed in DaCe programs') + + def visit_Raise(self, node): + raise DaCeSyntaxError(self, node, + 'exceptions disallowed in DaCe programs') + + def visit_Try(self, node): + raise DaCeSyntaxError(self, node, + 'exceptions disallowed in DaCe programs') + + def visit_TryExcept(self, node): + raise DaCeSyntaxError(self, node, + 'exceptions disallowed in DaCe programs') + + def visit_TryFinally(self, node): + raise DaCeSyntaxError(self, node, + 'exceptions disallowed in DaCe programs') + + def visit_ExceptHandler(self, node): + raise DaCeSyntaxError(self, node, + 'exceptions disallowed in DaCe programs') + + def visit_AsyncWith(self, node): + self.visit_With(node, is_async=True) + + def visit_Starred(self, node): + raise DaCeSyntaxError( + self, node, 'starred statements disallowed in DaCe programs') + + def visit_Ellipsis(self, node): + raise DaCeSyntaxError(self, node, + '"..." 
statements disallowed in DaCe programs')
+
+    # disallowed for now
+    def visit_ClassDef(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'classes disallowed (for now) in DaCe programs')
+
+    def visit_AsyncFor(self, node):
+        raise DaCeSyntaxError(
+            self, node,
+            'asynchronous loops disallowed (for now) in DaCe programs')
+
+    def visit_Await(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'await disallowed (for now) in DaCe programs')
+
+    # Data structures
+    def visit_Bytes(self, node):
+        raise DaCeSyntaxError(
+            self, node, 'bytestrings disallowed (for now) in DaCe programs')
+
+    def visit_Set(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'sets disallowed (for now) in DaCe programs')
+
+    def visit_Dict(self, node):
+        raise DaCeSyntaxError(
+            self, node, 'dictionaries disallowed (for now) in DaCe programs')
+
+    # Comprehensions
+    def visit_ListComp(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'comprehensions disallowed in DaCe programs')
+
+    def visit_GeneratorExp(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'comprehensions disallowed in DaCe programs')
+
+    def visit_SetComp(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'comprehensions disallowed in DaCe programs')
+
+    def visit_DictComp(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'comprehensions disallowed in DaCe programs')
+
+    def visit_comprehension(self, node):
+        raise DaCeSyntaxError(self, node,
+                              'comprehensions disallowed in DaCe programs')
+
+
+class ASTFindAndReplace(ast.NodeTransformer):
+    """ A Python AST transformer utility that finds and replaces names. """
+
+    def __init__(self, replacements, skip_subscripts=True):
+        self.replacement_dict = replacements
+        self.skip_subscripts = skip_subscripts
+
+    def visit_Subscript(self, node):
+        # Do not visit subscripts that contain a replacement
+        if rname(node) in self.replacement_dict and self.skip_subscripts:
+            return node
+        # NodeTransformer visitors must return the (possibly modified) node
+        return self.generic_visit(node)
+
+    def visit_Name(self, node):
+        if node.id in self.replacement_dict:
+            return ast.copy_location(
+                ast.Name(id=self.replacement_dict[node.id], ctx=node.ctx),
+                node)
+
+        return self.generic_visit(node)
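+
+
+# A minimal usage sketch of ASTFindAndReplace (a hypothetical helper, not
+# part of the original sources). With skip_subscripts=True, subscripts whose
+# *array name* is itself a replacement target are left untouched, while all
+# other occurrences are renamed:
+def _example_find_and_replace():
+    import astunparse
+    tree = ast.parse('b = A[i] + i')
+    tree = ASTFindAndReplace({'i': 'j'}).visit(tree)
+    return astunparse.unparse(tree).strip()  # roughly 'b = (A[j] + j)'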
""" + + def __init__(self, symbols): + self.symbols = symbols + self.locals = {} + self.top_function = True + + def resolve(self, node): + if node is None: + return None + if isinstance(node, tuple): + return tuple(self.resolve(n) for n in node) + return unparse(self.visit(node)) + + def visit_FunctionDef(self, node): + oldlocals = {} + oldlocals.update(self.locals) + oldtop = self.top_function + + # Register parameters as locals + if not self.top_function: + for arg in node.args.args: + self.locals[rname(arg)] = arg + + self.top_function = False + result = self.generic_visit(node) + self.top_function = oldtop + + self.locals = oldlocals + + return result + + def visit_TopLevelExpr(self, node): + if isinstance(node.value, ast.BinOp): + if isinstance(node.value.op, ast.LShift) or isinstance( + node.value.op, ast.RShift): + self.locals[rname(node.value.left)] = node.value.left + + node.value.right = self.visit(node.value.right) + return node + + return self.generic_visit(node) + + def visit_Name(self, node): + # Defining a local + if isinstance(node.ctx, ast.Store): + # TODO(later): Scope management + # Example: + # n = 5 + # @dace.program + # def prog(): + # def inner(): + # n = dace.define_local(...) + # use n (should be "n") + # use n (should be 5) + + self.locals[node.id] = node + return node + + if node.id not in self.symbols: + return node + if node.id in self.locals: + return node + + sym = self.symbols[node.id] + if isinstance(sym, symbolic.symbol): + return ast.copy_location(ast.Name(id=sym.name, ctx=node.ctx), node) + elif isinstance(sym, types.typeclass): + # Find dace module name + dacemodule = next( + k for k, v in self.symbols.items() + if isinstance(v, type(types)) and v.__name__ == 'dace') + + return ast.copy_location( + ast.Attribute( + value=ast.Name(id=dacemodule, ctx=ast.Load()), + attr=sym.to_string(), + ctx=node.ctx), node) + elif types.isconstant(sym): + return ast.copy_location(ast.Num(n=sym, ctx=node.ctx), node) + elif isinstance(sym, ast.Name): + return ast.copy_location(ast.Name(id=sym.id, ctx=node.ctx), node) + elif isinstance(sym, ast.AST): + return ast.copy_location(sym, node) + else: + return node + + +########################################################################## +# Function inlining + + +class CounterDict(object): + """ Dictionary object that counts how many times a value was added to + it. """ + + def __init__(self): + self.values = {} + + def get(self, key): + if key in self.values: + return self.values[key] + else: + return 0 + + def add(self, key, count=1): + if key not in self.values: + self.values[key] = count + else: + self.values[key] += count + + +class FunctionInliner(ExtNodeTransformer): + """ A Python AST transformer that inlines functions called (e.g., with + "dace.call") in an existing AST. """ + + def __init__(self, global_vars, modules, local_vars={}): + self.globals = global_vars + self.locals = local_vars + self.modules = modules + self.function_inline_counter = CounterDict() + + def visit_Call(self, node): + cnode = node + + # First, visit arguments and (possibly) inline them. 
+class FunctionInliner(ExtNodeTransformer):
+    """ A Python AST transformer that inlines functions called (e.g., with
+        "dace.call") in an existing AST. """
+
+    def __init__(self, global_vars, modules, local_vars=None):
+        self.globals = global_vars
+        # Avoid a mutable default argument; create a dict per instance
+        self.locals = local_vars if local_vars is not None else {}
+        self.modules = modules
+        self.function_inline_counter = CounterDict()
+
+    def visit_Call(self, node):
+        cnode = node
+
+        # First, visit arguments and (possibly) inline them. This takes care
+        # of "dace.call(func, dace.call(f2, arg), ...)" cases
+        node = self.generic_visit(node)
+
+        # Only accept "dace.call" calls
+        if isinstance(cnode.func, ast.Attribute) and cnode.func.attr == 'call':
+            # Verify that the module is DaCe
+            if self.modules[cnode.func.value.id] == 'dace':
+                # INLINE
+                if len(cnode.args) < 1:
+                    raise SyntaxError(
+                        'dace.call must have at least one parameter')
+                return self.inline_function(cnode, cnode.args[0])
+
+        return node
+
+    # Inline top-level calls as well
+    def visit_TopLevelExpr(self, node):
+        if isinstance(node.value, ast.Call):
+            node.value = self.visit_TopLevelCall(node.value)
+            return node
+        return self.generic_visit(node)
+
+    def _fname_and_module(self, funcnode):
+        funcmodule = None
+        if isinstance(funcnode, ast.Attribute):
+            funcmodule = funcnode.value.id
+            funcname = funcnode.attr
+        else:
+            funcname = funcnode.id
+        return (funcmodule, funcname)
+
+    def visit_TopLevelCall(self, node):
+        # If dace.call(...)
+        if isinstance(node.func, ast.Attribute) and node.func.attr == 'call':
+            return self.visit_Call(node)
+
+        funcmodule, funcname = self._fname_and_module(node.func)
+        if funcmodule is None and funcname in self.globals:
+            # First, visit arguments and (possibly) inline them. This takes
+            # care of "dace.call(func, dace.call(f2, arg), ...)" cases
+            node = self.generic_visit(node)
+
+            return self.inline_function(node, node.func)
+
+        return self.generic_visit(node)
+
+    def _transients_from_ast(self, src_ast):
+        results = set()
+        for astnode in ast.walk(src_ast):
+            if (isinstance(astnode, ast.Assign)
+                    and isinstance(astnode.value, ast.Call)):
+                modulename, _ = self._fname_and_module(astnode.value.func)
+                if (modulename is not None
+                        and self.modules[modulename] == 'dace'):
+                    # Don't allow assignment to tuples (for now)
+                    if len(astnode.targets) > 1:
+                        raise DaCeSyntaxError(
+                            self, astnode,
+                            'Assignment to tuples not supported (yet)')
+                    target = astnode.targets[0]
+                    if isinstance(target, ast.Tuple):
+                        if len(target.elts) > 1:
+                            raise DaCeSyntaxError(
+                                self, node,
+                                'Assignment to tuples not supported (yet)')
+                        target = target.elts[0]
+
+                    results.add(rname(target))
+        return results
+
+    def inline_function(self, cnode, funcnode):
+        funcmodule, funcname = self._fname_and_module(funcnode)
+        if funcmodule is None and funcname not in self.globals:
+            raise SyntaxError(
+                'Function %s not found (is it declared as @dace.external_function?)'
+                % funcname)
+        if funcmodule is not None:
+            raise SyntaxError('External DaCe functions should be' +
+                              ' imported directly using "from ' +
+                              '<module> import ..."')
+
+        self.function_inline_counter.add(funcname)
+
+        # Obtain the function object
+        f = None
+        if isinstance(self.globals[funcname], types._external_function):
+            f = self.globals[funcname].func
+        else:
+            f = self.globals[funcname]
+
+        # Parse that function's AST
+        src_ast, src_file, src_line, src = function_to_ast(f)
+
+        # Inline the function's internal dace.calls recursively
+        for astnode in ast.walk(src_ast):
+            if isinstance(astnode, ast.Call):
+                src_ast = FunctionInliner(self.globals,
+                                          self.modules).visit(src_ast)
+                break
+
+        # Replace the function's parameters with the values in arguments
+        func_args = src_ast.body[0].args.args
+
+        if cnode.func == funcnode:  # In case of calling a function directly
+            call_args = cnode.args[:]
+        else:  # In case of calling a function through dace.call
+            call_args = cnode.args[1:]
+        if len(func_args) != len(call_args):
+            raise SyntaxError(
+                'Mismatch in arguments to call %s. Expecting %d, got %d' %
+                (f.__name__, len(func_args), len(call_args)))
+
+        replacement_map = {  # parameter replacement map
+            rname(k): v
+            for k, v in zip(func_args, call_args)
+        }
+
+        # Obtain and rename transients as well. "tmp" --> "func1_tmp"
+        local_replacement_map = {
+            k: ast.Name(
+                ctx=ast.Load(),
+                id='%s%d_%s' % (funcname,
+                                self.function_inline_counter.get(funcname), k))
+            for k in self._transients_from_ast(src_ast)
+        }
+        for replacement in local_replacement_map.values():
+            for repl_ast in ast.walk(replacement):
+                if isinstance(repl_ast, ast.Name):
+                    if (repl_ast.id in self.globals
+                            or repl_ast.id in self.locals):
+                        raise SyntaxError(
+                            ('Cannot name a symbol %s due to function ' +
+                             'inlining, please choose another name') %
+                            repl_ast.id)
+        replacement_map.update(local_replacement_map)
+
+        src_ast = SymbolResolver(replacement_map).visit(src_ast)
+
+        # If the function has a return statement, then we need to
+        # evaluate the AST instead
+        if any(isinstance(stmt, ast.Return) for stmt in ast.walk(src_ast)):
+            if len(src_ast.body[0].body) > 1:
+                raise NotImplementedError(
+                    'Functions with a return value and more than one '
+                    'statement are not implemented')
+
+            # Inline the function by replacing the return value
+            return src_ast.body[0].body[0].value
+
+        return src_ast.body[0].body
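+
+
+# A hypothetical sketch (not part of the original sources) of what the
+# inliner does to user code; names are illustrative:
+#
+#     @dace.external_function
+#     def scale(x, c):
+#         tmp = dace.define_local([1], dace.float32)
+#         ...
+#
+#     # Inside a @dace.program, "dace.call(scale, A, 2)" is replaced by the
+#     # body of scale, with x/c substituted by A/2 and the transient "tmp"
+#     # renamed to "scale1_tmp" (unique per inlined call site).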
diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py
new file mode 100644
index 0000000000..ac16dffe6f
--- /dev/null
+++ b/dace/frontend/python/astutils.py
@@ -0,0 +1,317 @@
+""" Various AST parsing utilities for DaCe. """
+import ast
+import astunparse
+import sympy
+
+from dace import types, symbolic
+
+
+def rname(node):
+    """ Obtains names from different types of AST nodes. """
+
+    if isinstance(node, str):
+        return node
+    if isinstance(node, ast.Num):
+        return str(node.n)
+    if isinstance(node, ast.Name):  # form x
+        return node.id
+    if isinstance(node, ast.Subscript):  # form A[a:b,...,c:d]
+        return rname(node.value)
+    if isinstance(node, ast.Attribute):  # form @dace.attr_noparams
+        return rname(node.value) + '.' + rname(node.attr)
+    if isinstance(node, ast.Call):  # form @dace.attr(...)
+        if isinstance(node.func, ast.Name):
+            return node.func.id
+        return node.func.value.id + '.' + node.func.attr
+    if isinstance(node, ast.FunctionDef):  # form def func(...)
+        return node.name
+    if isinstance(node, ast.keyword):
+        return node.arg
+    try:
+        if isinstance(node, ast.arg):  # form func(..., arg, ...)
+            return node.arg
+    except AttributeError:  # ast.arg does not exist in Python 2
+        pass
+
+    raise TypeError('Invalid AST node type: ' + str(type(node)))
+
+
+def getkwarg(node, argname, default=None):
+    """ Helper function to get the AST node of a keyword argument (of form
+        "argname="). """
+
+    for kw in node.keywords:
+        if rname(kw) == argname:
+            return kw.value
+    return default
+
+
+def DaCeSyntaxError(visitor, node, err):
+    """ Constructs a syntax error with the corresponding file/line
+        information. """
+
+    try:
+        line = node.lineno
+        col = node.col_offset
+    except AttributeError:
+        line = 0
+        col = 0
+
+    return SyntaxError(err + "\n in File " + str(visitor.filename) +
+                       ", line " + str(line) + ":" + str(col) +
+                       ", in function " + str(visitor.curnode.name))
+
+
+def get_tuple(visitor, node):
+    """ Parses and returns a tuple from an AST node, including explicit None
+        values. """
+    if isinstance(node, ast.Num):  # A number (single axis)
+        return (node.n, )
+    if isinstance(
+            node,
+            ast.Tuple):  # Tuple. Assumes all axes are values. Form: (2,3,5)
Form: (2,3,5) + for v in node.elts: + if not isinstance(v, ast.Num): + raise DaCeSyntaxError(visitor, node, + 'Axis tuple can only contain integers') + return tuple(value.n for value in node.elts) + if isinstance(node, ast.Name): # Python 2.x variant of explicit "None" + if node.id == 'None': + return None + if node is None: # Argument not given + return None + + # Python 3 only + try: + if isinstance(node, ast.NameConstant + ): # Explicit None (or True). Example: "axis=None" + if node.value is None: + return node.value + except AttributeError: + pass + + raise DaCeSyntaxError( + visitor, node, + 'Invalid expression, expected tuple of constant numbers or None') + + +def subscript_to_ast_slice(node, without_array=False): + """ Converts an AST subscript to slice on the form + (, [<3-tuples of AST nodes>]). If an ast.Name is passed, returns + (name, None), implying the full range. + @param node: The AST node to convert. + @param without_array: If True, returns only the slice. Otherwise, + returns a 2-tuple of (array, range). + """ + + result_arr = None + result_slice = None + + if isinstance(node, ast.Name): + # None implies the full array. We can't create artificial + # (None, None, None) tuples, because we don't know the dimensionality of + # the array at this point + result_arr, result_slice = node.id, None + return result_slice if without_array else (result_arr, result_slice) + + if not isinstance(node, ast.Subscript): + raise TypeError('AST node is not a subscript') + + # ND Index + if isinstance(node.slice, ast.Index): + if isinstance(node.slice.value, ast.Tuple): + result_slice = [dim for dim in node.slice.value.elts] + else: + result_slice = [node.slice.value] + # 1D slice + elif isinstance(node.slice, ast.Slice): + result_slice = [(node.slice.lower, node.slice.upper, node.slice.step)] + else: # ND slice + result_slice = [] + + for d in node.slice.dims: + if isinstance(d, ast.Index): + result_slice.append(d.value) + else: + result_slice.append((d.lower, d.upper, d.step)) + + if without_array: + return result_slice + else: + return (rname(node.value), result_slice) + + +def subscript_to_ast_slice_recursive(node): + """ Converts an AST subscript to a slice in a recursive manner into nested + subscripts. + @see: subscript_to_ast_slice + """ + result = [] + while isinstance(node, ast.Subscript): + result.insert(0, subscript_to_ast_slice(node, True)) + node = node.value + + return result + + +def unparse(node): + """ Unparses an AST node to a Python string, chomping trailing newline. """ + if node is None: + return None + return astunparse.unparse(node).strip() + + +# Helper function to convert an ND subscript AST node to a list of 3-tuple +# slice strings +def subscript_to_slice(node, arrays, without_array=False): + """ Converts an AST subscript to slice on the form + (, [<3-tuples of indices>]). If an ast.Name is passed, return + (name, None), implying the full range. """ + + name, ast_slice = subscript_to_ast_slice(node) + if name in arrays: + arrname = name + else: + arrname = None + + rng = astrange_to_symrange(ast_slice, arrays, arrname) + if without_array: + return rng + else: + return name, rng + + +def astrange_to_symrange(astrange, arrays, arrname=None): + """ Converts an AST range (array, [(start, end, skip)]) to a symbolic math + range, using the obtained array sizes and resolved symbols. 
""" + if arrname is not None: + arrdesc = arrays[arrname] + + # If the array is a scalar, return None + if arrdesc.shape is None: + return None + + # If range is the entire array, use the array descriptor to obtain the + # entire range + if astrange is None: + return [ + (symbolic.pystr_to_symbolic(0), + symbolic.pystr_to_symbolic(types.symbol_name_or_value(s)) - 1, + symbolic.pystr_to_symbolic(1)) for s in arrdesc.shape + ] + + result = [None] * len(astrange) + for i, r in enumerate(astrange): + if isinstance(r, tuple): + begin, end, skip = r + # Default values + if begin is None: + begin = symbolic.pystr_to_symbolic(0) + else: + begin = symbolic.pystr_to_symbolic(unparse(begin)) + if end is None and arrname is None: + raise SyntaxError('Cannot define range without end') + elif end is not None: + end = symbolic.pystr_to_symbolic(unparse(end)) - 1 + else: + end = symbolic.pystr_to_symbolic( + types.symbol_name_or_value(arrdesc.shape[i])) - 1 + if skip is None: + skip = symbolic.pystr_to_symbolic(1) + else: + skip = symbolic.pystr_to_symbolic(unparse(skip)) + else: + # In the case where a single element is given + begin = symbolic.pystr_to_symbolic(unparse(r)) + end = begin + skip = symbolic.pystr_to_symbolic(1) + + result[i] = (begin, end, skip) + + return result + + +def negate_expr(node): + """ Negates an AST expression by adding a `Not` AST node in front of it. + """ + if hasattr(node, "__len__"): + if len(node) > 1: + raise ValueError("negate_expr only expects " + "single expressions, got: {}".format(node)) + expr = node[0] + else: + expr = node + newexpr = ast.Expr(value=ast.UnaryOp(op=ast.Not(), operand=expr)) + newexpr = ast.copy_location(newexpr, expr) + return ast.fix_missing_locations(newexpr) + + +class ExtNodeTransformer(ast.NodeTransformer): + """ A `NodeTransformer` subclass that walks the abstract syntax tree and + allows modification of nodes. As opposed to `NodeTransformer`, + this class is capable of traversing over top-level expressions in + bodies in order to discern DaCe statements from others. + """ + + # Default implementation of TopLevelExpr + def visit_TopLevelExpr(self, node): + return self.visit(node) + + def generic_visit(self, node): + for field, old_value in ast.iter_fields(node): + if isinstance(old_value, list): + new_values = [] + for value in old_value: + if isinstance(value, ast.AST): + if (field == 'body' + or field == 'orelse') and isinstance( + value, ast.Expr): + value = self.visit_TopLevelExpr(value) + else: + value = self.visit(value) + if value is None: + continue + elif not isinstance(value, ast.AST): + new_values.extend(value) + continue + new_values.append(value) + old_value[:] = new_values + elif isinstance(old_value, ast.AST): + new_node = self.visit(old_value) + if new_node is None: + delattr(node, field) + else: + setattr(node, field, new_node) + return node + + +class ExtNodeVisitor(ast.NodeVisitor): + """ A `NodeVisitor` subclass that walks the abstract syntax tree. + As opposed to `NodeVisitor`, this class is capable of traversing over + top-level expressions in bodies in order to discern DaCe statements + from others. 
""" + + def visit_TopLevel(self, node): + clsname = type(node).__name__ + if getattr(self, "visit_TopLevel" + clsname, False): + getattr(self, "visit_TopLevel" + clsname)(node) + else: + self.visit(node) + + def generic_visit(self, node): + for field, old_value in ast.iter_fields(node): + if isinstance(old_value, list): + for value in old_value: + if isinstance(value, ast.AST): + if (field == 'body' or field == 'orelse'): + clsname = type(value).__name__ + if getattr(self, "visit_TopLevel" + clsname, + False): + getattr(self, + "visit_TopLevel" + clsname)(value) + else: + self.visit(value) + else: + self.visit(value) + elif isinstance(old_value, ast.AST): + self.visit(old_value) + return node diff --git a/dace/frontend/python/decorators.py b/dace/frontend/python/decorators.py new file mode 100644 index 0000000000..f6e980446d --- /dev/null +++ b/dace/frontend/python/decorators.py @@ -0,0 +1,109 @@ +""" Python decorators for DaCe functions. """ + +from __future__ import print_function +from functools import wraps + +from dace import types +from dace.frontend.python import parser + + +def paramdec(dec): + """ Parameterized decorator meta-decorator. Enables using `@decorator`, + `@decorator()`, and `@decorator(...)` with the same function. """ + + @wraps(dec) + def layer(*args, **kwargs): + + # Allows the use of @decorator, @decorator(), and @decorator(...) + if len(kwargs) == 0 and len(args) == 1 and callable( + args[0]) and not isinstance(args[0], types.typeclass): + return dec(*args, **kwargs) + + @wraps(dec) + def repl(f): + return dec(f, *args, **kwargs) + + return repl + + return layer + + +############################################# + + +@paramdec +def program(f, *args, **kwargs): + """ DaCe program, entry point to a data-centric program. """ + + # Parses a python @dace.program function and returns an object that can + # be translated + return parser.DaceProgram(f, args, kwargs) + + +############################################# + + +@paramdec +def external_function(f, **alternative_implementations): + """ External functions that may be called within a DaCe program. """ + return types._external_function(f, alternative_implementations) + + +# Internal DaCe decorators, these are not actually run, but rewritten + + +# Dataflow constructs +@paramdec +def map(f, rng): + """ A Map is representation of parallel execution, containing + an integer set (Python range) for which its contents are run + concurrently. + @param rng: The map's range. + """ + return None + + +@paramdec +def consume(f, stream, pes): + """ Consume is a scope, like `Map`, that creates parallel execution. + Unlike `Map`, it creates a producer-consumer relationship between an + input stream and the contents. The contents are run by the given number + of processing elements, who will try to pop elements from the input + stream until a given quiescence condition is reached. + @param stream: The stream to pop from. + @param pes: The number of processing elements to use. + """ + return None + + +def tasklet(f): + """ A general procedure that cannot access any memory apart from incoming + and outgoing memlets. The DaCe framework cannot analyze these tasklets + for optimization. """ + return None + + +# Control-flow constructs +@paramdec +def iterate(f, rng): + """ A decorator version of a for loop, with a range of `rng`. + @param rng: The range of the for loop. + """ + return None + + +@paramdec +def loop(f, cond): + """ A decorator version of a while loop, with a looping condition `cond`. 
+
+
+# Control-flow constructs
+@paramdec
+def iterate(f, rng):
+    """ A decorator version of a for loop, with a range of `rng`.
+        @param rng: The range of the for loop.
+    """
+    return None
+
+
+@paramdec
+def loop(f, cond):
+    """ A decorator version of a while loop, with a looping condition `cond`.
+        @param cond: The condition of the while loop.
+    """
+    return None
+
+
+@paramdec
+def conditional(f, cond):
+    """ A decorator version of conditional execution, with an if-condition
+        `cond`.
+        @param cond: The condition of the branch.
+    """
+    return None
diff --git a/dace/frontend/python/depanalysis.py b/dace/frontend/python/depanalysis.py
new file mode 100644
index 0000000000..a7c4e880d6
--- /dev/null
+++ b/dace/frontend/python/depanalysis.py
@@ -0,0 +1,796 @@
+""" Data dependency analysis functionality, as well as functions to convert
+    an AST-parsed data-centric Python program into an SDFG. """
+import ast
+from collections import deque, OrderedDict
+from copy import deepcopy as dcpy
+import sympy
+
+from dace import data as dt, types, symbolic
+from dace.graph import edges as ed
+from dace.graph import nodes as nd
+from dace import subsets as sbs
+from dace import sdfg
+from dace.memlet import EmptyMemlet, Memlet
+from dace.frontend.python import astnodes, astutils
+from dace.frontend.python.astparser import MemletRemover, ModuleInliner
+
+
+def create_states_simple(pdp,
+                         out_sdfg,
+                         start_state=None,
+                         end_state=None,
+                         start_edge=None):
+    """ Creates a state per primitive, with the knowledge that they can be
+        optimized later.
+        @param pdp: A parsed dace program.
+        @param out_sdfg: The output SDFG.
+        @param start_state: The starting/parent state to connect from (for
+                            recursive calls).
+        @param end_state: The end/parent state to connect to (for
+                          recursive calls).
+        @param start_edge: The interstate edge to use when connecting from
+                           `start_state` (for recursive calls).
+        @return: A dictionary mapping between a state and the list of dace
+                 primitives included in it.
+    """
+    state_to_primitives = OrderedDict()
+
+    # Create starting state and edge
+    if start_state is None:
+        start_state = out_sdfg.add_state('start')
+        state_to_primitives[start_state] = []
+    if start_edge is None:
+        start_edge = ed.InterstateEdge()
+
+    previous_state = start_state
+    previous_edge = start_edge
+
+    for i, primitive in enumerate(pdp.children):
+        state = out_sdfg.add_state(primitive.name)
+        state_to_primitives[state] = []
+        # Edge that can be created on entry to control flow children
+        entry_edge = None
+
+        #########################################
+        # Cases depending on primitive type
+        #########################################
+
+        # Nothing special happens with a dataflow node (nested states are
+        # handled with a separate call to create_states_simple)
+        if isinstance(primitive, astnodes._DataFlowNode):
+            out_sdfg.add_edge(previous_state, state, previous_edge)
+            state_to_primitives[state] = [primitive]
+            previous_state = state
+            previous_edge = ed.InterstateEdge()
+
+        # Control flow needs to traverse into children nodes
+        elif isinstance(primitive, astnodes._ControlFlowNode):
+            # Iteration creates at least three states (begin, loop body
+            # states, end); the loop states are connected in a cycle, and
+            # begin connects directly to end for the case where the
+            # condition does not evaluate to true
+            if isinstance(primitive, astnodes._IterateNode):
+
+                condition = ast.parse(
+                    '(%s %s %s)' % (primitive.params[0], '<'
+                                    if primitive.range[0][2] >= 0 else '>',
+                                    primitive.range[0][1] + 1)).body[0]
+                condition_neg = astutils.negate_expr(condition)
+
+                # Loop-start state
+                lstart_state = out_sdfg.add_state(primitive.name + '_start')
+                state_to_primitives[lstart_state] = []
+                out_sdfg.add_edge(previous_state, lstart_state, previous_edge)
+                out_sdfg.add_edge(
+                    lstart_state,
+                    state,
+                    ed.InterstateEdge(
+                        assignments={
+                            primitive.params[0]: primitive.range[0][0]
+                        }))
+
+                # Loop-end state that jumps back to `state`
+                loop_state = out_sdfg.add_state(primitive.name +
'_end') + state_to_primitives[loop_state] = [] + # Connect loop + out_sdfg.add_edge( + loop_state, + state, + ed.InterstateEdge( + assignments={ + primitive.params[0]: + symbolic.pystr_to_symbolic(primitive.params[0]) + + primitive.range[0][2] + })) + + # End connection + previous_state = state + previous_edge = ed.InterstateEdge(condition=condition_neg) + + # Create children states + cmap = create_states_simple( + primitive, + out_sdfg, + state, + loop_state, + ed.InterstateEdge(condition=condition)) + state_to_primitives.update(cmap) + + # Loop is similar to iterate, but more general w.r.t. conditions + elif isinstance(primitive, astnodes._LoopNode): + loop_condition = primitive.condition + + # Entry + out_sdfg.add_edge(previous_state, state, previous_edge) + + # Loop-end state that jumps back to `state` + loop_state = out_sdfg.add_state(primitive.name + '_end') + state_to_primitives[loop_state] = [] + + # Loopback + out_sdfg.add_edge(loop_state, state, ed.InterstateEdge()) + # End connection + previous_state = state + previous_edge = ed.InterstateEdge( + condition=astutils.negate_expr(loop_condition)) + entry_edge = ed.InterstateEdge(condition=loop_condition) + + # Create children states + cmap = create_states_simple(primitive, out_sdfg, state, + loop_state, entry_edge) + state_to_primitives.update(cmap) + + elif isinstance(primitive, astnodes._IfNode): + if_condition = primitive.condition + # Check if we have an else node, otherwise add a skip condition + # ourselves + if (i + 1) < len(pdp.children) and isinstance( + pdp.children[i + 1], astnodes._ElseNode): + has_else = True + else_prim = pdp.children[i + 1] + else_condition = else_prim.condition + else: + has_else = False + else_condition = astutils.negate_expr(primitive.condition) + + # End-of-branch state (converge to this) + bend_state = out_sdfg.add_state(primitive.name + '_end') + state_to_primitives[bend_state] = [] + + # Entry + out_sdfg.add_edge(previous_state, state, previous_edge) + + # Create children states + cmap = create_states_simple( + primitive, + out_sdfg, + state, + bend_state, + ed.InterstateEdge(condition=if_condition)) + state_to_primitives.update(cmap) + + # Handle 'else' condition + if not has_else: + out_sdfg.add_edge( + state, + bend_state, + ed.InterstateEdge(condition=else_condition)) + else: + # Recursively parse 'else' primitive's children + cmap = create_states_simple( + else_prim, + out_sdfg, + state, + bend_state, + ed.InterstateEdge(condition=else_condition)) + state_to_primitives.update(cmap) + + # Exit + previous_state = bend_state + previous_edge = ed.InterstateEdge() + + elif isinstance(primitive, astnodes._ElseNode): + if i - 1 < 0 or not isinstance(pdp.children[i - 1], + astnodes._IfNode): + raise SyntaxError('Found else state without matching if') + + # If 'else' state is correct, we already processed it + del state_to_primitives[state] + out_sdfg.remove_node(state) + + # Connect to end_state (and create it if necessary) + if end_state is None: + end_state = out_sdfg.add_state('end') + state_to_primitives[end_state] = [] + out_sdfg.add_edge(previous_state, end_state, previous_edge) + + return state_to_primitives + + +def _make_full_range(memlet: astnodes._Memlet): + fullRange = sbs.Range([(0, s - 1, 1) for s in memlet.data.shape]) + fullMemlet = astnodes._Memlet(memlet.data, + memlet.dataname, memlet.attribute, + fullRange.num_elements(), None, None, + fullRange, memlet.veclen, None, None, {}) + return fullMemlet + + +def _full_memlet_from_array(arrayname, array): + fullRange = sbs.Range([(0, 
s - 1, 1) for s in array.shape])
+    fullMemlet = astnodes._Memlet(array, arrayname, None,
+                                  fullRange.num_elements(), None, None,
+                                  fullRange, 1, None, None, {})
+    return fullMemlet
+
+
+def inherit_dependencies(prim):
+
+    # Inject tasklets for map nodes and push down dependencies
+    if (isinstance(prim, (astnodes._MapNode, astnodes._ConsumeNode))
+            and len(prim.children) == 0):
+        tasklet = astnodes._TaskletNode(prim.name, prim.ast)
+        tasklet.parent = prim
+        tasklet.inputs = OrderedDict(
+            [(k, v) for k, v in prim.inputs.items() if '__DACEIN_' not in k])
+        tasklet.outputs = OrderedDict(
+            [(k, v) for k, v in prim.outputs.items() if '__DACEIN_' not in k])
+        prim.inputs = OrderedDict(
+            [(k, v) for k, v in prim.inputs.items() if '__DACEIN_' in k])
+        prim.outputs = OrderedDict(
+            [(k, v) for k, v in prim.outputs.items() if '__DACEIN_' in k])
+        prim.children.append(tasklet)
+
+    # The recursive dependencies of this node which we will return
+    dependIn = OrderedDict()
+    dependOut = OrderedDict()
+
+    # Add own dependencies (input)
+    inputQueue = deque(prim.inputs.items())
+    while len(inputQueue) > 0:
+        arrname, memlet = inputQueue.popleft()
+        fullMemlet = _make_full_range(memlet)
+        dependIn[fullMemlet.dataname] = fullMemlet
+        # Additional dependencies (e.g., as a result of indirection)
+        for aname, additional_arr in memlet.otherdeps.items():
+            additional_astmemlet = _full_memlet_from_array(
+                aname, additional_arr)
+            dependIn[additional_astmemlet.dataname] = additional_astmemlet
+
+    # Add own dependencies (output)
+    outputQueue = deque(prim.outputs.items())
+    while len(outputQueue) > 0:
+        arrname, memlet = outputQueue.popleft()
+        fullMemlet = _make_full_range(memlet)
+        dependOut[fullMemlet.dataname] = fullMemlet
+        if isinstance(memlet.subset, astnodes._Memlet):
+            # `deque` has no `push` method; enqueue the nested memlet as a
+            # (name, memlet) pair so the unpacking above keeps working
+            outputQueue.append((memlet.subset.dataname, memlet.subset))
+
+    # Add recursively inherited dependencies
+    inheritIn = OrderedDict()
+    inheritOut = OrderedDict()
+    arrs = prim.transients.keys()
+    for child in prim.children:
+        childIn, childOut = inherit_dependencies(child)
+        # Only inherit dependencies from arrays defined in this scope
+        inheritIn.update(
+            OrderedDict(
+                [(k, v) for k, v in childIn.items() if k not in arrs]))
+        inheritOut.update(
+            OrderedDict(
+                [(k, v) for k, v in childOut.items() if k not in arrs]))
+
+        # We should not overwrite an explicit dependency with an inherited
+        # one: this is most likely a programming mistake
+        for key in inheritIn.keys():
+            if key in prim.inputs:
+                raise ValueError("Inherited dependency from '" + child.name +
+                                 "' overwrites explicit dependency in '" +
+                                 prim.name + "' (" + str(prim.inputs[key]) +
+                                 ")")
+        for key in inheritOut.keys():
+            if key in prim.outputs:
+                raise ValueError("Inherited dependency from '" + child.name +
+                                 "' overwrites explicit dependency in '" +
+                                 prim.name + "' (" + str(prim.outputs[key]) +
+                                 ")")
+    prim.inputs.update(inheritIn)
+    prim.outputs.update(inheritOut)
+    dependIn.update(dcpy(inheritIn))
+    dependOut.update(dcpy(inheritOut))
+
+    if isinstance(prim, astnodes._ControlFlowNode):
+        # Don't inherit dependencies across control flow boundaries
+        return OrderedDict(), OrderedDict()
+    else:
+        return dependIn, dependOut
+
+
+def _subset_has_indirection(subset):
+    for dim in subset:
+        if not isinstance(dim, tuple):
+            dim = [dim]
+        for r in dim:
+            if symbolic.contains_sympy_functions(r):
+                return True
+    return False
+
+
+def _add_astmemlet_edge(sdfg,
+                        state,
+                        src_node,
+                        src_conn,
+                        dst_node,
+                        dst_conn,
+                        ast_memlet,
+                        data=None,
+                        wcr=None,
+                        wcr_identity=None):
+    try:
+        if src_node.data == 
dst_node.data: + raise RuntimeError("Added edge connection data nodes " + "with same descriptor: {} to {}".format( + src_node, dst_node)) + except AttributeError: + pass + if _subset_has_indirection(ast_memlet.subset): + add_indirection_subgraph(sdfg, state, src_node, dst_node, ast_memlet) + return + + if data is not None: + raise NotImplementedError('This should never happen') + + memlet = Memlet(ast_memlet.dataname, ast_memlet.num_accesses, + ast_memlet.subset, ast_memlet.veclen, wcr, wcr_identity) + state.add_edge(src_node, src_conn, dst_node, dst_conn, memlet) + + +def _get_input_symbols(inputs, freesyms): + syminputs = set( + str(i)[9:] for i in inputs.keys() if str(i).startswith('__DACEIN_')) + return freesyms & syminputs + + +# TODO: The following two functions can be replaced with better dataflow +# generation procedures + + +def input_node_for_array(state, data: str): + # If the node appears as one of the source nodes, return it first + for n in state.source_nodes(): + if isinstance(n, nd.AccessNode): + if n.data == data: + return n + # Otherwise, if the node is located elsewhere, return it + for n in state.nodes(): + if isinstance(n, nd.AccessNode): + if n.data == data: + return n + + return nd.AccessNode(data) + + +def output_node_for_array(state, data: str): + for n in state.sink_nodes(): + if isinstance(n, nd.AccessNode): + if n.data == data: + return n + + return nd.AccessNode(data) + + +def _build_dataflow_graph_recurse(sdfg, state, primitives, modules, superEntry, + super_exit): + # Array of pairs (exit node, memlet) + exit_nodes = [] + + if len(primitives) == 0: + # Inject empty tasklets into empty states + primitives = [astnodes._EmptyTaskletNode("Empty Tasklet", None)] + + for prim in primitives: + label = prim.name + + # Expand node to get entry and exit points + if isinstance(prim, astnodes._MapNode): + if len(prim.children) == 0: + raise ValueError("Map node expected to have children") + mapNode = nd.Map( + label, prim.params, prim.range, is_async=prim.is_async) + # Add connectors for inputs that exist as array nodes + entry = nd.MapEntry( + mapNode, + _get_input_symbols(prim.inputs, prim.range.free_symbols)) + exit = nd.MapExit(mapNode) + elif isinstance(prim, astnodes._ConsumeNode): + if len(prim.children) == 0: + raise ValueError("Consume node expected to have children") + consumeNode = nd.Consume(label, (prim.params[1], prim.num_pes), + prim.condition) + entry = nd.ConsumeEntry(consumeNode) + exit = nd.ConsumeExit(consumeNode) + elif isinstance(prim, astnodes._ReduceNode): + rednode = nd.Reduce(prim.ast, prim.axes, prim.identity) + state.add_node(rednode) + entry = rednode + exit = rednode + elif isinstance(prim, astnodes._TaskletNode): + if isinstance(prim, astnodes._EmptyTaskletNode): + tasklet = nd.EmptyTasklet(prim.name) + else: + # Remove memlets from tasklet AST + if prim.language == types.Language.Python: + clean_code = MemletRemover().visit(prim.ast) + clean_code = ModuleInliner(modules).visit(clean_code) + else: # Use external code from tasklet definition + if prim.extcode is None: + raise SyntaxError("Cannot define an intrinsic " + "tasklet without an implementation") + clean_code = prim.extcode + tasklet = nd.Tasklet( + prim.name, + set(prim.inputs.keys()), + set(prim.outputs.keys()), + code=clean_code, + language=prim.language, + code_global=prim.gcode) # TODO: location=prim.location + + # Need to add the tasklet in case we're in an empty state, where no + # edge will be drawn to it + state.add_node(tasklet) + entry = tasklet + exit = tasklet + + elif 
isinstance(prim, astnodes._NestedSDFGNode): + prim.sdfg.parent = state + prim.sdfg._parent_sdfg = sdfg + prim.sdfg.update_sdfg_list([]) + nsdfg = nd.NestedSDFG(prim.name, prim.sdfg, + set(prim.inputs.keys()), + set(prim.outputs.keys())) + state.add_node(nsdfg) + entry = nsdfg + exit = nsdfg + + elif isinstance(prim, astnodes._ProgramNode): + return + elif isinstance(prim, astnodes._ControlFlowNode): + continue + else: + raise TypeError("Node type not implemented: " + + str(prim.__class__)) + + # Add incoming edges + for varname, memlet in prim.inputs.items(): + arr = memlet.dataname + if (prim.parent is not None + and memlet.dataname in prim.parent.transients.keys()): + node = input_node_for_array(state, memlet.dataname) + + # Add incoming edge into transient as well + # FIXME: A bit hacked? + if arr in prim.parent.inputs: + astmem = prim.parent.inputs[arr] + _add_astmemlet_edge(sdfg, state, superEntry, None, node, + None, astmem) + + # Remove local name from incoming edge to parent + prim.parent.inputs[arr].local_name = None + elif superEntry: + node = superEntry + else: + node = input_node_for_array(state, memlet.dataname) + + # Destination connector inference + # Connected to a tasklet or a nested SDFG + dst_conn = (memlet.local_name + if isinstance(entry, nd.CodeNode) else None) + # Connected to a scope as part of its range + if str(varname).startswith('__DACEIN_'): + dst_conn = str(varname)[9:] + # Handle special case of consume input stream + if (isinstance(entry, nd.ConsumeEntry) + and memlet.data == prim.stream): + dst_conn = 'IN_stream' + + # If a memlet that covers this input already exists, skip + # generating this one; otherwise replace memlet with ours + skip_incoming_edge = False + remove_edge = None + for e in state.edges_between(node, entry): + if e.data.data != memlet.dataname or dst_conn != e.dst_conn: + continue + if e.data.subset.covers(memlet.subset): + skip_incoming_edge = True + break + elif memlet.subset.covers(e.data.subset): + remove_edge = e + break + else: + print('WARNING: Performing bounding-box union on', + memlet.subset, 'and', e.data.subset, '(in)') + e.data.subset = sbs.bounding_box_union( + e.data.subset, memlet.subset) + e.data.num_accesses += memlet.num_accesses + skip_incoming_edge = True + break + + if remove_edge is not None: + state.remove_edge(remove_edge) + + if skip_incoming_edge == False: + _add_astmemlet_edge(sdfg, state, node, None, entry, dst_conn, + memlet) + + # If there are no inputs, generate a dummy edge + if superEntry and len(prim.inputs) == 0: + state.add_edge(superEntry, None, entry, None, EmptyMemlet()) + + if len(prim.children) > 0: + # Recurse + inner_outputs = _build_dataflow_graph_recurse( + sdfg, state, prim.children, modules, entry, exit) + # Infer output node for each memlet + for i, (out_src, mem) in enumerate(inner_outputs): + # If there is no such array in this primitive's outputs, + # it's an external array (e.g., a map in a map). In this case, + # connect to the exit node + if mem.dataname in prim.outputs: + inner_outputs[i] = (out_src, prim.outputs[mem.dataname]) + else: + inner_outputs[i] = (out_src, mem) + else: + inner_outputs = [(exit, mem) for mem in prim.outputs.values()] + + # Add outgoing edges + for out_src, astmem in inner_outputs: + + data = astmem.data + dataname = astmem.dataname + + # If WCR is not none, it needs to be handled in the code. Check for + # this after, as we only expect it for one distinct case + wcr_was_handled = astmem.wcr is None + + # TODO: This is convoluted. 
We should find a more readable
+            # way of connecting the outgoing edges.
+
+            if super_exit is None:
+
+                # Assert that we're in a top-level node
+                if ((not isinstance(prim.parent, astnodes._ProgramNode)) and
+                        (not isinstance(prim.parent,
+                                        astnodes._ControlFlowNode))):
+                    raise RuntimeError("Expected to be at the top node")
+
+                # Looks hacky
+                src_conn = (astmem.local_name if isinstance(
+                    out_src, (nd.Tasklet, nd.NestedSDFG)) else None)
+
+                # Here we just need to connect memlets directly to their
+                # respective data nodes
+                out_tgt = output_node_for_array(state, astmem.dataname)
+
+                # If a memlet that covers this output already exists, skip
+                # generating this one; otherwise replace memlet with ours
+                skip_outgoing_edge = False
+                remove_edge = None
+                for e in state.edges_between(out_src, out_tgt):
+                    if e.data.data != astmem.dataname or src_conn != e.src_conn:
+                        continue
+                    if e.data.subset.covers(astmem.subset):
+                        skip_outgoing_edge = True
+                        break
+                    elif astmem.subset.covers(e.data.subset):
+                        remove_edge = e
+                        break
+                    else:
+                        print('WARNING: Performing bounding-box union on',
+                              astmem.subset, 'and', e.data.subset, '(out)')
+                        e.data.subset = sbs.bounding_box_union(
+                            e.data.subset, astmem.subset)
+                        e.data.num_accesses += astmem.num_accesses
+                        skip_outgoing_edge = True
+                        break
+
+                if skip_outgoing_edge == True:
+                    continue
+                if remove_edge is not None:
+                    state.remove_edge(remove_edge)
+
+                _add_astmemlet_edge(
+                    sdfg,
+                    state,
+                    out_src,
+                    src_conn,
+                    out_tgt,
+                    None,
+                    astmem,
+                    wcr=astmem.wcr,
+                    wcr_identity=astmem.wcr_identity)
+                wcr_was_handled = (True if astmem.wcr is not None else
+                                   wcr_was_handled)
+
+                # If the program defines another output, connect it too.
+                # This refers to the case where we have streams, which
+                # must define an input and output, and sometimes this output
+                # is defined in pdp.outputs
+                if (isinstance(out_tgt, nd.AccessNode)
+                        and isinstance(out_tgt.desc(sdfg), dt.Stream)):
+                    try:
+                        stream_memlet = next(
+                            v for k, v in prim.parent.outputs.items()
+                            if k == out_tgt.data)
+                        stream_output = output_node_for_array(
+                            state, stream_memlet.dataname)
+                        _add_astmemlet_edge(sdfg, state, out_tgt, None,
+                                            stream_output, None,
+                                            stream_memlet)
+                    except StopIteration:  # Stream output not found, skip
+                        pass
+
+            else:  # We're in a nest
+
+                if isinstance(prim, astnodes._ScopeNode):
+                    # We're a map or a consume node that needs to connect
+                    # our exit to either an array or to the super_exit
+                    if data.transient and dataname in prim.parent.transients:
+                        # Connect the exit directly
+                        out_tgt = output_node_for_array(state, data.dataname)
+                        _add_astmemlet_edge(sdfg, state, out_src, None,
+                                            out_tgt, None, astmem)
+                    else:
+                        # This is either a transient defined in an outer
+                        # scope, or an I/O array, so redirect through the
+                        # exit node
+                        _add_astmemlet_edge(sdfg, state, out_src, None,
+                                            super_exit, None, astmem)
+                        # Instruct outer recursion layer to continue the route
+                        exit_nodes.append((super_exit, astmem))
+                elif isinstance(
+                        prim,
+                        (astnodes._TaskletNode, astnodes._NestedSDFGNode)):
+                    # We're a tasklet, and need to connect either to the exit
+                    # if the array is I/O or is defined in a scope further
+                    # out, or directly to the transient if it's defined
+                    # locally
+                    if dataname in prim.parent.transients:
+                        # This is a local transient variable, so connect to
+                        # it directly
+                        out_tgt = output_node_for_array(state, data.dataname)
+                        _add_astmemlet_edge(sdfg, state, out_src,
+                                            astmem.local_name, out_tgt, None,
+                                            astmem)
+                    else:
+                        # This is an I/O array, or an outer level transient,
+                        # so redirect 
through the exit node + _add_astmemlet_edge( + sdfg, + state, + out_src, + astmem.local_name, + super_exit, + None, + astmem, + wcr=astmem.wcr, + wcr_identity=astmem.wcr_identity) + exit_nodes.append((super_exit, astmem)) + if astmem.wcr is not None: + wcr_was_handled = True # Sanity check + else: + raise TypeError("Unexpected node type: {}".format( + type(out_src).__name__)) + + if not wcr_was_handled and not isinstance(prim, + astnodes._ScopeNode): + raise RuntimeError("Detected unhandled WCR for primitive '{}' " + "of type {}. WCR is only expected for " + "tasklets in a map/consume scope.".format( + prim.name, + type(prim).__name__)) + + return exit_nodes + + +def build_dataflow_graph(sdfg, state, primitives, modules): + _build_dataflow_graph_recurse(sdfg, state, primitives, modules, None, None) + + +def add_indirection_subgraph(sdfg, graph, src, dst, memlet): + """ Replaces the specified edge in the specified graph with a subgraph that + implements indirection without nested AST memlet objects. """ + if not isinstance(memlet, astnodes._Memlet): + raise TypeError("Expected memlet to be astnodes._Memlet") + + indirect_inputs = set() + indirect_outputs = set() + + # Scheme for multi-array indirection: + # 1. look for all arrays and accesses, create set of arrays+indices + # from which the index memlets will be constructed from + # 2. each separate array creates a memlet, of which num_accesses = len(set) + # 3. one indirection tasklet receives them all + original array and + # produces the right output index/range memlet + ######################### + # Step 1 + accesses = OrderedDict() + newsubset = dcpy(memlet.subset) + for dimidx, dim in enumerate(memlet.subset): + # Range/Index disambiguation + direct_assignment = False + if not isinstance(dim, tuple): + dim = [dim] + direct_assignment = True + + for i, r in enumerate(dim): + for expr in sympy.preorder_traversal(r): + if symbolic.is_sympy_userfunction(expr): + fname = expr.func.__name__ + if fname not in accesses: + accesses[fname] = [] + + # Replace function with symbol (memlet local name to-be) + if expr.args in accesses[fname]: + aindex = accesses[fname].index(expr.args) + toreplace = 'index_' + fname + '_' + str(aindex) + else: + accesses[fname].append(expr.args) + toreplace = 'index_' + fname + '_' + str( + len(accesses[fname]) - 1) + + if direct_assignment: + newsubset[dimidx] = r.subs(expr, toreplace) + else: + newsubset[dimidx][i] = r.subs(expr, toreplace) + ######################### + # Step 2 + ind_inputs = {'__ind_' + memlet.local_name} + ind_outputs = {'lookup'} + # Add accesses to inputs + for arrname, arr_accesses in accesses.items(): + for i in range(len(arr_accesses)): + ind_inputs.add('index_%s_%d' % (arrname, i)) + + tasklet = nd.Tasklet("Indirection", ind_inputs, ind_outputs) + + input_index_memlets = [] + for arrname, arr_accesses in accesses.items(): + arr = memlet.otherdeps[arrname] + for i, access in enumerate(arr_accesses): + # Memlet to load the indirection index + indexMemlet = Memlet(arrname, 1, sbs.Indices(list(access)), 1) + input_index_memlets.append(indexMemlet) + graph.add_edge(src, None, tasklet, "index_%s_%d" % (arrname, i), + indexMemlet) + + ######################### + # Step 3 + # Create new tasklet that will perform the indirection + indirection_ast = ast.parse("lookup = {arr}[{index}]".format( + arr='__ind_' + memlet.local_name, + index=', '.join([symbolic.symstr(s) for s in newsubset]))) + # Conserve line number of original indirection code + tasklet.code = 
ast.copy_location(indirection_ast.body[0], memlet.ast) + + # Create transient variable to trigger the indirected load + if memlet.num_accesses == 1: + storage = sdfg.add_scalar( + '__' + memlet.local_name + '_value', + memlet.data.dtype, + transient=True) + else: + storage = sdfg.add_array( + '__' + memlet.local_name + '_value', + memlet.data.dtype, + storage=types.StorageType.Default, + transient=True, + shape=memlet.bounding_box_size()) + indirectRange = sbs.Range([(0, s - 1, 1) for s in storage.shape]) + dataNode = nd.AccessNode('__' + memlet.local_name + '_value') + + # Create memlet that depends on the full array that we look up in + fullRange = sbs.Range([(0, s - 1, 1) for s in memlet.data.shape]) + fullMemlet = Memlet(memlet.dataname, memlet.num_accesses, fullRange, + memlet.veclen) + graph.add_edge(src, None, tasklet, '__ind_' + memlet.local_name, + fullMemlet) + + # Memlet to store the final value into the transient, and to load it into + # the tasklet that needs it + indirectMemlet = Memlet('__' + memlet.local_name + '_value', + memlet.num_accesses, indirectRange, memlet.veclen) + graph.add_edge(tasklet, 'lookup', dataNode, None, indirectMemlet) + + valueMemlet = Memlet('__' + memlet.local_name + '_value', + memlet.num_accesses, indirectRange, memlet.veclen) + graph.add_edge(dataNode, None, dst, memlet.local_name, valueMemlet) diff --git a/dace/frontend/python/ndarray.py b/dace/frontend/python/ndarray.py new file mode 100644 index 0000000000..792e5c961c --- /dev/null +++ b/dace/frontend/python/ndarray.py @@ -0,0 +1,187 @@ +""" Array types and wrappers used in DaCe's Python frontend. """ +from __future__ import print_function +import ctypes +import enum +import inspect +import numpy +import itertools +from collections import deque + +from dace import symbolic, types + +########################################################### +# NDArray type + + +class ndarray(numpy.ndarray): + """ An N-dimensional array wrapper around `numpy.ndarray` that enables + symbolic sizes. """ + + def __new__(cls, + shape, + dtype=types.float32, + materialize_func=None, + allow_conflicts=False, + *args, + **kwargs): + """ Initializes a DaCe ND-array. + @param shape: The array shape (may contain symbols). + @param dtype: The array data type. + @param materialize_func: An optional string that contains a method + to materialize array contents on demand. + If not None, the array is not allocated + within the DaCe program. + @param allow_conflicts: If True, suppresses warnings on conflicting + array writes in DaCe programs without a + matching conflict resolution memlet. + """ + # Avoiding import loops + from dace import data + + tmpshape = shape + shape = [symbolic.eval(s, 0) for s in shape] + + kwargs.update({'dtype': dtype.type}) + + res = numpy.ndarray.__new__(cls, shape, *args, **kwargs) + res._symlist = symbolic.symlist(tmpshape) + for _, sym in res._symlist.items(): + sym._arrays_to_update.append(res) + + if not isinstance(dtype, types.typeclass): + dtype = types.typeclass(dtype.type) + + res.descriptor = data.Array( + dtype, + tmpshape, + materialize_func=materialize_func, + transient=False, + allow_conflicts=allow_conflicts) + return res + + def update_resolved_symbol(self, sym): + """ Notifies an array that a symbol has been resolved so that it + can be resized. 
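
The symbolic-size machinery above registers every array with the symbols appearing in its shape (via `sym._arrays_to_update`), so resolving a symbol later resizes all affected arrays in place. A minimal usage sketch, assuming this version's `dace.symbol` exposes a `set()` method that notifies registered arrays through `update_resolved_symbol`:

```python
import dace
from dace.frontend.python import ndarray as nd

N = dace.symbol('N')                  # unresolved symbol, evaluates to 0 here
A = nd.ndarray([N], dtype=dace.float64)

N.set(1024)                           # notifies A via update_resolved_symbol,
assert A.shape == (1024,)             # which resizes it in place
```
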
""" + self.resize( + [symbolic.eval(s, 0) for s in self.descriptor.shape], + refcheck=False) + self._symlist = symbolic.symlist(self.descriptor.shape) + + def missing_syms(self): + return ','.join( + [s for s, v in self._symlist.items() if not v.is_initialized()]) + + def __setitem__(self, key, value): + if self.descriptor.materialize_func is not None: + raise PermissionError( + "You cannot write into an Immaterial storage.") + return numpy.ndarray.__setitem__(self, key, value) + + def __getitem__(self, key): + if 0 in self.shape: + self.update_resolved_symbol(None) + if 0 in self.shape: + raise IndexError( + 'Cannot create sub-array, not all symbols are set " "(missing symbols: %s)' + % self.missing_syms()) + return numpy.ndarray.__getitem__(self, key) + + # Python 2.x compatibility + def __getslice__(self, *args): + if 0 in self.shape: + raise IndexError( + 'Cannot create sub-array, not all symbols are set (missing symbols: %s)' + % self.missing_syms()) + return numpy.ndarray.__getslice__(self, *args) + + def __array_finalize__(self, obj): + if obj is None: + return + from dace import data + + # Create a new descriptor + self.descriptor = data.Array( + types.typeclass(obj.dtype.type), + obj.shape, + materialize_func=None, + transient=False, + allow_conflicts=False) + + self._symlist = {} + + def __lshift__(self, other): + pass + + def __rshift__(self, other): + pass + + def __hash__(self): + return hash(self.data.tobytes()) + + def __call__(self, *args): + return self + + +class transient(ndarray): + """ Transient DaCe array subclass. """ + + def __new__(cls, *args, **kwargs): + res = ndarray.__new__(cls, *args, **kwargs) + res.descriptor.transient = True + return res + + +class stream(object): + """ Stream array object in Python. Mostly used in the Python SDFG + simulator. """ + + def __init__(self, dtype, shape): + from dace import data + + self._type = dtype + self._shape = shape + self.descriptor = data.Stream(dtype, 1, 0, shape, True) + self.queue_array = numpy.ndarray(shape, dtype=deque) + for i in itertools.product(*(range(s) for s in shape)): + self.queue_array[i] = deque() + + @property + def shape(self): + return self.shape + + def __getitem__(self, key): + return self.queue_array.__getitem__(key) + + def __getslice__(self, *args): + return self.queue_array.__getslice__(*args) + + +def scalar(dtype=types.float32, allow_conflicts=False): + """ Convenience function that defines a scalar (array of size 1). """ + return ndarray([1], dtype, allow_conflicts=allow_conflicts) + + +def define_local(dimensions, dtype=types.float32, allow_conflicts=False): + """ Defines a transient array in a DaCe program. """ + return transient(dimensions, dtype=dtype, allow_conflicts=allow_conflicts) + + +def define_local_scalar(dtype=types.float32, allow_conflicts=False): + """ Defines a transient scalar (array of size 1) in a DaCe program. """ + return transient([1], dtype=dtype, allow_conflicts=allow_conflicts) + + +def define_stream(dtype=types.float32, buffer_size=0): + """ Defines a local stream in a DaCe program. """ + return define_streamarray([1], dtype=dtype, buffer_size=buffer_size) + + +def define_streamarray(dimensions, dtype=types.float32, buffer_size=0): + """ Defines a local stream array in a DaCe program. """ + return stream(dtype, dimensions) + + +def asarray(array): + """ Converts an existing Numpy NDArray to DaCe NDArray. 
""" + obj = numpy.asarray(array).view(ndarray) + return obj diff --git a/dace/frontend/python/ndloop.py b/dace/frontend/python/ndloop.py new file mode 100644 index 0000000000..eed068a9b7 --- /dev/null +++ b/dace/frontend/python/ndloop.py @@ -0,0 +1,63 @@ +""" A single generator that creates an N-dimensional for loop in Python. """ +import itertools +from dace.frontend.python import ndarray + +# Python 3 compatibility for xrange +try: + xxrange = xrange +except NameError: + xxrange = range + + +def slicetoxrange(s): + """ Helper function that turns a slice into a range (for iteration). """ + if isinstance(s, int): + return xxrange(s, s + 1) + + ifnone = lambda a, b: b if a is None else a + + return xxrange(ifnone(s.start, 0), s.stop + 1, ifnone(s.step, 1)) + + +def tupletoxrange(s): + """ Helper function that turns a tuple into a range (for iteration). """ + if isinstance(s, int): + return xxrange(s, s + 1) + + ifnone = lambda a, b: b if a is None else a + ifscalar = lambda a: a[0] if isinstance(a, ndarray.ndarray) else a + allconds = lambda a, b: ifnone(ifscalar(a), b) + + return xxrange(allconds(s[0], 0), ifscalar(s[1]) + 1, allconds(s[2], 1)) + + +def NDLoop(ndslice, internal_function, *args, **kwargs): + """ Wrapped generator that calls an internal function in an N-dimensional + for-loop in Python. + @param ndslice: Slice or list of slices (`slice` objects) to loop over. + @param internal_function: Function to call in loop. + @param *args: Arguments to `internal_function`. + @param **kwargs: Keyword arguments to `internal_function`. + @return: N-dimensional loop index generator. + """ + if isinstance(ndslice, int) or isinstance(ndslice, slice): + ndxrange = (slicetoxrange(ndslice), ) + else: + ndxrange = tuple(slicetoxrange(d) for d in ndslice) + for indices in itertools.product(*ndxrange): + internal_function(*(indices + args), **kwargs) + + +def ndrange(slice_list): + """ Generator that creates an N-dimensional for loop in Python. + @param slice_list: Slice or list of slices (as tuples or `slice`s) + to loop over. + @return: N-dimensional loop index generator. + """ + if not isinstance(slice_list, list): + ndxrange = (tupletoxrange(slice_list), ) + else: + ndxrange = tuple(tupletoxrange(d) for d in slice_list) + + for indices in itertools.product(*ndxrange): + yield indices diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py new file mode 100644 index 0000000000..1410395fae --- /dev/null +++ b/dace/frontend/python/parser.py @@ -0,0 +1,320 @@ +""" DaCe Python parsing functionality and entry point to Python frontend. """ +from __future__ import print_function +from collections import OrderedDict +from functools import wraps +import inspect +import ast +import copy +import sys +import numpy + +from dace import data, symbolic, types +from dace.config import Config +from dace.frontend.python import astparser, astutils, depanalysis +from dace.sdfg import SDFG +from dace.graph import labeling + + +def _create_datadescriptor(obj): + """ Creates a data descriptor from various types of objects. 
+ @see: dace.data.Data + """ + if isinstance(obj, data.Data): + return obj + + try: + return obj.descriptor + except AttributeError: + if isinstance(obj, numpy.ndarray): + return data.Array( + dtype=types.typeclass(obj.dtype.type), shape=obj.shape) + if symbolic.issymbolic(obj): + return data.Scalar(symbolic.symtype(obj)) + if isinstance(obj, types.typeclass): + return data.Scalar(obj) + return data.Scalar(types.typeclass(type(obj))) + + +def _get_type_annotations(f, f_argnames, decorator_args): + """ Obtains types from decorator or from type annotations in a function. + """ + type_annotations = {} + if hasattr(f, '__annotations__'): + type_annotations.update(f.__annotations__) + + # Type annotation conditions + has_args = len(decorator_args) > 0 + has_annotations = len(type_annotations) > 0 + if 'return' in type_annotations: + raise TypeError('DaCe programs do not have a return type') + if has_args and has_annotations: + raise SyntaxError('DaCe programs can only have decorator arguments ' + + '(\'@dace.program(...)\') or type annotations ' + + '(\'def program(arr: type, ...)\'), but not both') + + # Alert if there are any discrepancies between annotations and arguments + if has_args: + # Make sure all arguments are annotated + if len(decorator_args) != len(f_argnames): + raise SyntaxError( + 'Decorator arguments must match number of DaCe ' + + 'program parameters (expecting ' + str(len(f_argnames)) + ')') + # Return arguments and their matched decorator annotation + return { + k: _create_datadescriptor(v) + for k, v in zip(f_argnames, decorator_args) + } + elif has_annotations: + # Make sure all arguments are annotated + if len(type_annotations) != len(f_argnames): + raise SyntaxError( + 'Either none or all DaCe program parameters must ' + + 'have type annotations') + return {k: _create_datadescriptor(v) for k, v in type_annotations.items()} + + +def _get_argnames(f): + """ Returns a Python function's argument names. """ + try: + return inspect.getfullargspec(f).args + except AttributeError: + return inspect.getargspec(f).args + + +def _compile_module(s, name=''): + """ Compiles a string representing a python module (file or code) and + returns the resulting global objects as a dictionary mapping name->val. + @param name: Optional name for better error message handling. + """ + + gen_module = {} + code = compile(s, name, 'exec') + exec(code, gen_module) + return gen_module + + +def parse_from_file(filename, *compilation_args): + """ Try to parse all DaCe programs in `filename` and return a list of + obtained SDFGs. Raises exceptions in case of compilation errors. + Also accepts optional compilation arguments containing types and symbol + values. + """ + + with open(filename, 'r') as f: + code = f.read() + + mod = _compile_module(code, filename) + + programs = [ + program for program in mod.values() + if isinstance(program, DaceProgram) + ] + + return [parse_function(p, *compilation_args) for p in programs] + + +def parse_from_function(function, *compilation_args, strict=None): + """ Try to parse a DaceProgram object and return the `dace.SDFG` object + that corresponds to it. + @param function: DaceProgram object (obtained from the `@dace.program` + decorator). + @param compilation_args: Various compilation arguments e.g. types. + @param strict: Whether to apply strict transformations or not (None + uses configuration-defined value). + @return: The generated SDFG object. 
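
For reference, `_create_datadescriptor` above maps arbitrary argument values to DaCe data descriptors, roughly as follows:

```python
import numpy
import dace
from dace import data
from dace.frontend.python.parser import _create_datadescriptor

# NumPy arrays become data.Array descriptors with matching dtype and shape
desc = _create_datadescriptor(numpy.zeros((4, 4), dtype=numpy.float32))
assert isinstance(desc, data.Array) and tuple(desc.shape) == (4, 4)

# Plain Python scalars fall through to the final data.Scalar case
assert isinstance(_create_datadescriptor(5), data.Scalar)
```
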
+ """ + if not isinstance(function, DaceProgram): + raise TypeError( + 'Function must be of type dace.frontend.python.DaceProgram') + + # Obtain parsed DaCe program + pdp, modules = function.generate_pdp(*compilation_args) + + # Create an empty SDFG + sdfg = SDFG(pdp.name, pdp.argtypes) + + sdfg.set_sourcecode(pdp.source, 'python') + + # Populate SDFG with states and nodes, according to the parsed DaCe program + + # 1) Inherit dependencies and inject tasklets + # 2) Traverse program graph and recursively split into states, + # annotating edges with their transition conditions. + # 3) Add arrays, streams, and scalars to the SDFG array store + # 4) Eliminate empty states with no conditional outgoing transitions + # 5) Label states in topological order + # 6) Construct dataflow graph for each state + + # Step 1) + for primitive in pdp.children: + depanalysis.inherit_dependencies(primitive) + + # Step 2) + state_primitives = depanalysis.create_states_simple(pdp, sdfg) + + # Step 3) + for dataname, datadesc in pdp.all_arrays().items(): + sdfg.add_datadesc(dataname, datadesc) + + # Step 4) Absorb next state into current, if possible + oldstates = list(sdfg.topological_sort(sdfg.start_state)) + for state in oldstates: + if state not in sdfg.nodes(): # State already removed + continue + if sdfg.out_degree(state) == 1: + edge = sdfg.out_edges(state)[0] + nextState = edge.dst + if not edge.data.is_unconditional(): + continue + if sdfg.in_degree(nextState) > 1: # If other edges point to state + continue + if len(state_primitives[nextState]) > 0: # Don't fuse full states + continue + + outEdges = list(sdfg.out_edges(nextState)) + for e in outEdges: + # Construct new edge from the current assignments, new + # assignments, and new conditions + newEdge = copy.deepcopy(edge.data) + newEdge.assignments.update(e.data.assignments) + newEdge.condition = e.data.condition + sdfg.add_edge(state, e.dst, newEdge) + sdfg.remove_node(nextState) + + # Step 5) + stateList = sdfg.topological_sort(sdfg.start_state) + for i, state in enumerate(stateList): + if state.label is None or state.label == "": + state.set_label("s" + str(i)) + + # Step 6) + for i, state in enumerate(stateList): + depanalysis.build_dataflow_graph(sdfg, state, state_primitives[state], + modules) + + # Fill in scope entry/exit connectors + sdfg.fill_scope_connectors() + + # Memlet propagation + if sdfg.propagate: + labeling.propagate_labels_sdfg(sdfg) + + # Drawing the SDFG before strict transformations + sdfg.draw_to_file(recursive=True) + + # Apply strict transformations automatically + if (strict == True + or (strict is None + and Config.get_bool('optimizer', 'automatic_state_fusion'))): + sdfg.apply_strict_transformations() + + # Drawing the SDFG (again) to a .dot file + sdfg.draw_to_file(recursive=True) + + # Validate SDFG + sdfg.validate() + + return sdfg + + +class DaceProgram: + """ A data-centric program object, obtained by decorating a function with + `@dace.program`. """ + + def __init__(self, f, args, kwargs): + self.f = f + self.args = args + self.kwargs = kwargs + self._name = f.__name__ + + @property + def name(self): + return self._name + + def to_sdfg(self, *args, strict=None): + """ Parses the DaCe function into an SDFG. """ + return parse_from_function(self, *args, strict=strict) + + def compile(self, *args, strict=None, specialize=None): + """ Convenience function that parses and compiles a DaCe program. 
""" + sdfg = parse_from_function(self, *args, strict=strict) + return sdfg.compile(specialize=specialize) + + def __call__(self, *args, strict=None, specialize=None): + """ Convenience function that parses, compiles, and runs a DaCe + program. """ + binaryobj = self.compile(*args, strict=strict, specialize=specialize) + return binaryobj(*args) + + def generate_pdp(self, *compilation_args): + """ Generates the parsed AST representation of a DaCe program. + @param compilation_args: Various compilation arguments e.g., types. + @return: A 2-tuple of (program, modules), where `program` is a + `dace.astnodes._ProgramNode` representing the parsed DaCe + program, and `modules` is a dictionary mapping imported + module names to their actual module names (for maintaining + import aliases). + """ + dace_func = self.f + args = self.args + argnames = _get_argnames(dace_func) + + if not argnames: + raise SyntaxError( + 'DaCe program must contain at least one parameter') + + # If exist, obtain type annotations (for compilation) + argtypes = _get_type_annotations(dace_func, argnames, args) + + # Parse argument types from call + if not argtypes: + if not compilation_args: + raise SyntaxError( + 'DaCe program compilation requires either type annotations ' + 'or arrays') + + # Parse compilation arguments + if len(compilation_args) != len(argnames): + raise SyntaxError( + 'Arguments must match DaCe program parameters (expecting ' + + str(len(argnames)) + ')') + argtypes = { + k: _create_datadescriptor(v) + for k, v in zip(argnames, compilation_args) + } + ############################################# + + # Parse allowed global variables + # (for inferring types and values in the DaCe program) + global_vars = { + k: v + for k, v in dace_func.__globals__.items() if types.isallowed(v) + } + modules = { + k: v.__name__ + for k, v in dace_func.__globals__.items() + if types.ismodule_and_allowed(v) + } + modules['builtins'] = '' + + # Add symbols as globals with their actual names (sym_0 etc.) + global_vars.update({ + v.name: v + for k, v in global_vars.items() if isinstance(v, symbolic.symbol) + }) + + # Add keyword arguments as additional globals + global_vars.update( + {k: v + for k, v in self.kwargs.items() if types.isallowed(v)}) + + argtypes_ordered = OrderedDict() + for param in argnames: + argtypes_ordered[param] = argtypes[param] + + # Parse AST to create the SDFG + pdp = astparser.parse_dace_program(dace_func, argtypes_ordered, + global_vars, modules) + + # Transform parsed DaCe code into a DaCe program (Stateful DFG) + return pdp, modules diff --git a/dace/frontend/python/simulator.py b/dace/frontend/python/simulator.py new file mode 100644 index 0000000000..4e9f3a3354 --- /dev/null +++ b/dace/frontend/python/simulator.py @@ -0,0 +1,703 @@ +""" A Python simulator for DaCe programs. Currently reads and runs Python + functions rather than any SDFG. """ + +from __future__ import print_function +import ast +import copy +from functools import wraps +import inspect +import numpy +import sys +import numpy + +from dace import data, symbolic, types +from dace.config import Config +from dace.frontend.python import astparser, astnodes, astutils, ndloop, ndarray +from dace.frontend.python.astutils import unparse +from dace.frontend.python.parser import DaceProgram + + +def simulate(dace_program: DaceProgram, *args): + """ Simulate a DaCe program using Python. + @param dace_program: A program function annotated with `@dace.program`. + @param *args: Program arguments to pass. 
+ """ + pdp, modules = dace_program.generate_pdp() + + # Transform the decorated AST into working python code (annotated so + # that debugging works) + simulated_ast = SimulatorTransformer(pdp).visit(pdp.ast) + mod = ast.Module(body=simulated_ast, lineno=1) + mod = ast.fix_missing_locations(mod) + + # Compile the transformed AST + codeobj = compile(mod, pdp.filename, 'exec') + + fname = dace_program.name + + if Config.get_bool('debugprint'): + print("Simulating DaCe program with name", fname) + + param_symbols = {} + + if len(pdp.params) != len(args): + raise SyntaxError('Argument number mismatch in \'' + fname + + '\', expecting ' + str(len(args))) + + ################################################################## + # Disallow external variables + # EXCEPTIONS: + # * The dace module ('import dace') + # * The math module ('import math') + # * Constants (types int, float, dace.int*, dace.float*) + # * DaCe symbols that have been defined in @dace.program args + ################################################################## + + f_globals = {} + + # WORKAROUND: Works around a bug in CPython 2.x where True and + # False are undefined + f_globals['True'] = True + f_globals['False'] = False + ###################### + + # Allow certain namespaces/modules and constants + f_globals.update(pdp.globals) + + # Resolve symbols + symbols = {} + symbols.update(symbolic.getsymbols( + args)) # from parameter values (externally defined as "dace.symbol") + symbols.update(param_symbols) # from parameter values (constant inputs) + + resolve = {} + for gname, gval in f_globals.items(): + if isinstance(gval, symbolic.symbol): + if gval.name in symbols: + resolve[gname] = gval.get() # Raise exception if undefined + else: + resolve[gname] = None # Mark unrelated symbols for removal + + f_globals.update(resolve) + + # Remove unrelated symbols from globals + for rk, rv in resolve.items(): + if rv is None: + del f_globals[rk] + + # Resolve symbols in arguments as well + newargs = tuple(symbolic.eval(a) for a in args) + ################################################################## + + # Store parameter objects + pdp.arrayobjs = { + k: v + for k, v in zip(pdp.params, newargs) if isinstance(v, ndarray.ndarray) + } + + # Simulate f + ################################ + # Obtain function object + gen_module = {} + gen_module.update(f_globals) + exec(codeobj, gen_module) + cfunc = gen_module[fname] + + # Run function + result = cfunc(*newargs) + ################################ + + return result + + +class RangeStorage: + """ Range storage object that is injected to the `_` variable in order to + determine DaCe primitive extents at runtime. """ + + def __init__(self): + self.range = [] + + def __getitem__( + self, + key): # Set object's range every time it is called with a range + self.range = key + return self + + +def converttype(argument, cvt_type, argname): + """ Helper function to convert a scalar argument to its type. """ + if isinstance(argument, ndarray.ndarray): + return argument + + # Convert type + converted = cvt_type.type(argument) + + # Try to cast back to the original type. 
If the value has changed + # (e.g., out of bounds, lost precision), raise exception + origtype = type(argument) + if origtype(converted) != argument: + raise TypeError('Type conversion of argument \'' + argname + + '\' resulted in loss of precision, please ' + + 'cast explicitly before calling program') + + return converted + + +def _copy_location(newnode, node): + return ast.fix_missing_locations(ast.copy_location(newnode, node)) + + +class SimulatorTransformer(ast.NodeTransformer): + """ A Python AST transformer that converts a DaCe program into runnable + Python code for the simulator. """ + + def __init__(self, pdp): + self.pdp = pdp + self.curprim = None + self.module_name = None + self.storeOnAssignment = {} # Mapping from local names to memlets + self.accumOnAssignment = {} # Mapping from local names to memlets + self.curchild = -1 + + # Visiting a DaCe primitive + def visit_FunctionDef(self, node): + after_nodes = [] + + if self.curprim is None: + self.curprim = self.pdp + self.curchild = -1 + if isinstance(node.decorator_list[0], ast.Call): + self.module_name = node.decorator_list[0].func.value.id + else: + self.module_name = node.decorator_list[0].value.id + # Strip decorator + del node.decorator_list[0] + + oldchild = self.curchild + oldprim = self.curprim + + else: + if len(node.decorator_list) == 0: + return self.generic_visit(node) + dec = node.decorator_list[0] + if isinstance(dec, ast.Call): + decname = astparser.rname(dec.func.attr) + else: + decname = astparser.rname(dec.attr) + + if decname in [ + 'map', 'async_map', 'reduce', 'async_reduce', 'consume', + 'async_consume', 'tasklet', 'async_tasklet', 'iterate', + 'loop', 'conditional' + ]: + self.curchild += 1 + + oldchild = self.curchild + oldprim = self.curprim + self.curprim = self.curprim.children[self.curchild] + self.curchild = -1 + + if isinstance(self.curprim, astnodes._MapNode): + newnode = \ + _copy_location(ast.For(target=ast.Tuple(ctx=ast.Store(), + elts=[ast.Name(id=name, ctx=ast.Store()) for name in self.curprim.params]), + iter=ast.parse('%s.ndrange(%s)' % (self.module_name, self.curprim.range.pystr())).body[0].value, + body=node.body, orelse=[]), + node) + node = newnode + elif isinstance(self.curprim, astnodes._ConsumeNode): + stream = self.curprim.stream + if isinstance(self.curprim.stream, ast.AST): + stream = unparse(self.curprim.stream) + if '[' not in stream: + stream += '[0]' + + newnode = \ + _copy_location(ast.While( + test=ast.parse('len(%s) > 0' % stream).body[0].value, + body=node.body, orelse=[]), + node) + node = newnode + node.body.insert( + 0, + _copy_location( + ast.parse('%s = %s.popleft()' % (str( + self.curprim.params[0]), stream)).body[0], + node)) + + elif isinstance(self.curprim, astnodes._TaskletNode): + # Strip decorator + del node.decorator_list[0] + + newnode = \ + _copy_location(ast.parse('if True: pass').body[0], node) + newnode.body = node.body + newnode = ast.fix_missing_locations(newnode) + node = newnode + elif isinstance(self.curprim, astnodes._ReduceNode): + in_memlet = self.curprim.inputs['input'] + out_memlet = self.curprim.outputs['output'] + # Create reduction call + params = [unparse(p) for p in node.decorator_list[0].args] + params.extend([ + unparse(kp) for kp in node.decorator_list[0].keywords + ]) + reduction = ast.parse( + '%s.simulator.simulate_reduce(%s, %s)' % + (self.module_name, node.name, + ', '.join(params))).body[0] + reduction = _copy_location(reduction, node) + reduction = ast.increment_lineno(reduction, + len(node.body) + 1) + reduction = 
ast.fix_missing_locations(reduction) + + # Strip decorator + del node.decorator_list[0] + + after_nodes.append(reduction) + elif isinstance(self.curprim, astnodes._IterateNode): + newnode = \ + _copy_location(ast.For(target=ast.Tuple(ctx=ast.Store(), + elts=[ast.Name(id=name, ctx=ast.Store()) for name in self.curprim.params]), + iter=ast.parse('%s.ndrange(%s)' % (self.module_name, self.curprim.range.pystr())).body[0].value, + body=node.body, orelse=[]), + node) + newnode = ast.fix_missing_locations(newnode) + node = newnode + elif isinstance(self.curprim, astnodes._LoopNode): + newnode = \ + _copy_location(ast.While(test=node.decorator_list[0].args[0], + body=node.body, orelse=[]), + node) + newnode = ast.fix_missing_locations(newnode) + node = newnode + else: + raise RuntimeError('Unimplemented primitive %s' % decname) + else: + return self.generic_visit(node) + + newbody = [] + end_stmts = [] + substitute_stmts = [] + # Incrementally build new body from original body + for stmt in node.body: + if isinstance(stmt, ast.Expr): + res, append, prepend = self.VisitTopLevelExpr(stmt) + if res is not None: + newbody.append(res) + if append is not None: + end_stmts.extend(append) + if prepend is not None: + substitute_stmts.extend(prepend) + else: + subnodes = self.visit(stmt) + if subnodes is not None: + if isinstance(subnodes, list): + newbody.extend(subnodes) + else: + newbody.append(subnodes) + node.body = newbody + end_stmts + + self.curchild = oldchild + self.curprim = oldprim + + substitute_stmts.append(node) + if len(after_nodes) > 0: + return substitute_stmts + after_nodes + return substitute_stmts + + def VisitTopLevelExpr(self, node): + # DaCe memlet expression + if isinstance(node.value, ast.BinOp): + rhs = node.value.right + lhs = node.value.left + arrays = self.curprim.arrays() + + if isinstance(node.value.op, ast.LShift): + # Dynamic access. Emit nothing and load memory on encounter + if isinstance(rhs, ast.Call) and ast.literal_eval( + rhs.args[0]) == -1: + array_name = rhs.func.id + stripped_subscript = '%s[:]' % (array_name) + self.storeOnAssignment[node.value.left.id] = \ + ast.parse(stripped_subscript).body[0].value + return None, None, None + + if isinstance(rhs, ast.Subscript) and isinstance( + rhs.value, ast.Call): + + # Dynamic access. 
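
The core rewrite that `VisitTopLevelExpr` performs on memlet expressions can be reproduced in isolation. For the simple `a << A[i]` case handled just below, the shift expression becomes a plain assignment:

```python
import ast

tree = ast.parse('a << A[i]')
expr = tree.body[0]                  # ast.Expr wrapping a BinOp(LShift)
assign = ast.Assign(targets=[expr.value.left], value=expr.value.right)
assign.targets[0].ctx = ast.Store()  # LHS switches from Load to Store
tree.body[0] = ast.fix_missing_locations(ast.copy_location(assign, expr))

scope = {'A': [10, 20, 30], 'i': 1}
exec(compile(tree, '<memlet>', 'exec'), scope)
assert scope['a'] == 20              # 'a << A[i]' executed as 'a = A[i]'
```
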
Emit nothing and load memory on encounter + if ast.literal_eval(rhs.value.args[0]) == -1: + array_name = rhs.value.func.id + stripped_subscript = '%s[%s]' % (array_name, + unparse(rhs.slice)) + self.storeOnAssignment[node.value.left.id] = \ + ast.parse(stripped_subscript).body[0].value + return None, None, None + + rhs = ast.Subscript( + value=rhs.value.func, ctx=ast.Load(), slice=rhs.slice) + + result = _copy_location( + ast.Assign(targets=[node.value.left], value=rhs), node) + result.targets[0].ctx = ast.Store() + return result, None, None + # END of "a << b" + elif isinstance(node.value.op, ast.RShift): + # If the memlet refers to a sub-array (view), also add an expression to initialize it + init_expr = None + result = None + prefix = [] + + if isinstance(rhs, ast.Subscript): + # Index subscript expression ("tmp >> b(1, sum)[i,j,k,l]") + if isinstance(rhs.value, ast.Call): + # Only match expressions with possible write-conflict resolution, such as "A(...)[...]" + array_name = rhs.value.func.id + stripped_subscript = '%s[%s]' % (array_name, + unparse(rhs.slice)) + + # WCR initialization with identity value + if len(rhs.value.args) >= 3: + prefix.append( + _copy_location( + ast.parse( + '%s = %s' % + (stripped_subscript, + unparse(rhs.value.args[2]))).body[0], + node)) + + # Dynamic access. Emit nothing and store memory on assignment + if ast.literal_eval(rhs.value.args[0]) == -1: + if len(rhs.value.args) >= 2: + self.accumOnAssignment[node.value.left.id] = \ + (stripped_subscript, rhs.value.args[1]) + else: + self.storeOnAssignment[node.value.left.id] = \ + ast.parse(stripped_subscript).body[0].value + return init_expr, None, prefix + + # Make sure WCR function exists + if len(rhs.value.args) >= 2: + result = ast.parse( + '%s = (%s)(%s, %s)' % + (stripped_subscript, unparse( + rhs.value.args[1]), stripped_subscript, + node.value.left.id)).body[0] + result = _copy_location(result, node) + else: + result = ast.parse( + '%s = %s' % (stripped_subscript, + node.value.left.id)).body[0] + result = _copy_location(result, node) + else: + array_name = rhs.value.id + + if not isinstance(rhs.slice, ast.Index): + init_expr = _copy_location( + ast.Assign( + targets=[ + ast.Name( + id=node.value.left.id, ctx=ast.Store()) + ], + value=ast.Subscript( + value=ast.Name( + id=array_name, ctx=ast.Load()), + slice=rhs.slice, + ctx=ast.Load())), node) + elif not isinstance(rhs, ast.Subscript): + if isinstance(rhs, ast.Call): + array_name = rhs.func + else: + array_name = rhs + + lhs_name = lhs.id + + # In case of "tmp >> array", write "array[:]" + if node.value.left.id in self.curprim.transients: + init_expr = None + # If reading from a single stream ("b << stream") + elif (array_name.id in arrays + and isinstance(arrays[array_name.id], data.Stream)): + if arrays[array_name.id].shape == [1]: + init_expr = _copy_location( + ast.parse('{v} = {q}[0]'.format( + v=lhs_name, q=array_name.id)).body[0], + node) + return init_expr, None, [] + else: + init_expr = _copy_location( + ast.Assign( + targets=[ + ast.Name(id=lhs_name, ctx=ast.Store()) + ], + value=ast.Subscript( + value=ast.Name( + id=array_name.id, ctx=ast.Load()), + slice=ast.Slice( + lower=None, upper=None, step=None), + ctx=ast.Load())), node) + + # If we are setting a stream's sink + if lhs_name in arrays and isinstance( + arrays[lhs_name], data.Stream): + result = ast.parse( + '{arr}[0:len({q}[0])] = list({q}[0])'.format( + arr=rhs.id, q=lhs.id)).body[0] + result = _copy_location(result, node) + + # If WCR function exists + elif isinstance(rhs, ast.Call) 
and len(rhs.args) >= 2: + # WCR initialization with identity value + if len(rhs.args) >= 3: + prefix.append( + _copy_location( + ast.parse('%s[:] = %s' % + (array_name.id, + unparse(rhs.args[2]))).body[0], + node)) + + # Dynamic access. Emit nothing and store memory on assignment + if ast.literal_eval(rhs.args[0]) == -1: + self.accumOnAssignment[lhs.id] = (array_name.id, + rhs.args[1]) + return init_expr, None, prefix + + result = ast.parse( + '%s[:] = (%s)(%s[:], %s)' % + (array_name.id, unparse(rhs.args[1]), + array_name.id, node.value.left.id)).body[0] + result = _copy_location(result, node) + + else: + result = _copy_location( + ast.Assign( + targets=[ + ast.Subscript( + value=ast.Name( + id=array_name.id, ctx=ast.Load()), + slice=ast.Slice( + lower=None, upper=None, step=None), + ctx=ast.Store()) + ], + value=node.value.left), node) + + if result is None: + result = _copy_location( + ast.Assign( + targets=[node.value.right], value=node.value.left), + node) + result.targets[0].ctx = ast.Store() + return init_expr, [result], prefix + # END of "a >> b" + + return self.generic_visit(node), [], None + + def visit_Name(self, node): + if node.id in self.storeOnAssignment: + subscript = self.storeOnAssignment[node.id] + newnode = copy.deepcopy(subscript) + newnode.ctx = node.ctx + return _copy_location(newnode, node) + + return self.generic_visit(node) + + def visit_Assign(self, node): + if astutils.rname(node.targets[0]) in self.accumOnAssignment: + var_name = astutils.rname(node.targets[0]) + array_name, accum = self.accumOnAssignment[var_name] + if isinstance(node.targets[0], ast.Subscript): + array_name += '[' + unparse(node.targets[0].slice) + ']' + if '[' not in array_name: + array_name += '[:]' + + newnode = ast.parse('{out} = {accum}({out}, {val})'.format( + out=array_name, accum=unparse(accum), + val=unparse(node.value))).body[0] + newnode = _copy_location(newnode, node) + return newnode + + return self.generic_visit(node) + + def visit_Call(self, node): + if '.push' in astutils.rname(node.func): + node.func.attr = 'append' + return self.generic_visit(node) + + # Control flow: for-loop is the same as dace.iterate in the right context + def visit_For(self, node): + if not isinstance(self.curprim, astnodes._DataFlowNode): + self.curchild += 1 + + oldchild = self.curchild + oldprim = self.curprim + self.curprim = self.curprim.children[self.curchild] + self.curchild = -1 + + newbody = [] + end_stmts = [] + substitute_stmts = [] + # Incrementally build new body from original body + for stmt in node.body: + if isinstance(stmt, ast.Expr): + res, append, prepend = self.VisitTopLevelExpr(stmt) + if res is not None: + newbody.append(res) + if append is not None: + end_stmts.extend(append) + if prepend is not None: + substitute_stmts.extend(prepend) + else: + subnodes = self.visit(stmt) + if subnodes is not None: + if isinstance(subnodes, list): + newbody.extend(subnodes) + else: + newbody.append(subnodes) + node.body = newbody + end_stmts + substitute_stmts.append(node) + + self.curchild = oldchild + self.curprim = oldprim + return substitute_stmts + return self.generic_visit(node) + + # Control flow: while-loop is the same as dace.loop in the right context + def visit_While(self, node): + return self.visit_For(node) + + # Control flow: if-condition is the same as dace.conditional in the right context + def visit_If(self, node): + if not isinstance(self.curprim, astnodes._DataFlowNode): + self.curchild += 1 + + oldchild = self.curchild + oldprim = self.curprim + self.curprim = 
self.curprim.children[self.curchild] + self.curchild = -1 + + newbody = [] + end_stmts = [] + substitute_stmts = [] + # Incrementally build new body from original body + for stmt in node.body: + if isinstance(stmt, ast.Expr): + res, append, prepend = self.VisitTopLevelExpr(stmt) + if res is not None: + newbody.append(res) + if append is not None: + end_stmts.extend(append) + if prepend is not None: + substitute_stmts.extend(prepend) + else: + subnodes = self.visit(stmt) + if subnodes is not None: + if isinstance(subnodes, list): + newbody.extend(subnodes) + else: + newbody.append(subnodes) + node.body = newbody + end_stmts + + self.curchild = oldchild + self.curprim = oldprim + + # Process 'else'/'elif' statements + if len(node.orelse) > 0: + self.curchild += 1 + + oldchild = self.curchild + oldprim = self.curprim + self.curprim = self.curprim.children[self.curchild] + self.curchild = -1 + + newbody = [] + end_stmts = [] + # Incrementally build new body from original body + for stmt in node.orelse: + if isinstance(stmt, ast.Expr): + res, append, prepend = self.VisitTopLevelExpr(stmt) + if res is not None: + newbody.append(res) + if append is not None: + end_stmts.extend(append) + if prepend is not None: + substitute_stmts.extend(prepend) + else: + subnodes = self.visit(stmt) + if subnodes is not None: + if isinstance(subnodes, list): + newbody.extend(subnodes) + else: + newbody.append(subnodes) + node.orelse = newbody + end_stmts + + self.curchild = oldchild + self.curprim = oldprim + + substitute_stmts.append(node) + return substitute_stmts + + return self.generic_visit(node) + + +def simulate_reduce(op, in_array, out_array, axis=None, identity=None): + inshape = numpy.shape(in_array) + outshape = numpy.shape(out_array) + + # Argument validation + if axis is None and (len(outshape) != 1 or outshape[0] != 1): + raise RuntimeError("Cannot reduce to non-scalar value") + if axis is not None and (axis < 0 or axis >= len(in_array.shape)): + raise RuntimeError("Cannot reduce in nonexistent axis " + str(axis)) + + unreduced = outshape[:axis] + (inshape[axis], ) + outshape[axis:] + if unreduced != inshape: + raise RuntimeError("Incompatible shapes in reduction: " + + str(inshape) + " -> " + str(outshape)) + # End of argument validation + + # Reduce everything + if axis is None: + storevalue = True + + # If we have an initial value to insert + if identity is not None: + out_array[0] = identity + storevalue = False + + for i in numpy.nditer(in_array): + if storevalue: # If no identity value given, store first value as output + out_array[0] = i + storevalue = False + else: + out_array[0] = op(out_array[0], i) + + else: # Reduce a single axis + storevalue = True + + # If we have an initial value to insert + if identity is not None: + # Store identity scalar in output array + out_array[:] = identity + storevalue = False + + # Determine reduction slice (A[:,:,...,:,i,:,...,:]) + red_slice = [slice(None, None, None) for i in inshape] + for i in ndloop.xxrange(inshape[axis]): + red_slice[axis] = slice(i, i + 1, None) + + inslice = in_array[red_slice] + + if storevalue: + # Store initial value + for arrout, arrin in zip( + numpy.nditer(out_array, op_flags=['readwrite']), + numpy.nditer(inslice)): + arrout[...] = arrin + storevalue = False + else: + # Reduce entire (N-1)-dimensional tensor for the given slice + for arrout, arrin in zip( + numpy.nditer(out_array, op_flags=['readwrite']), + numpy.nditer(inslice)): + arrout[...] 
= op(arrout, arrin) diff --git a/dace/frontend/tensorflow/__init__.py b/dace/frontend/tensorflow/__init__.py new file mode 100644 index 0000000000..910095ba68 --- /dev/null +++ b/dace/frontend/tensorflow/__init__.py @@ -0,0 +1 @@ +from .tensorflow import * diff --git a/dace/frontend/tensorflow/tensorflow.py b/dace/frontend/tensorflow/tensorflow.py new file mode 100644 index 0000000000..939506da21 --- /dev/null +++ b/dace/frontend/tensorflow/tensorflow.py @@ -0,0 +1,2579 @@ +# -*- coding: utf-8 -*- +# Author: Roman Haag + +# TODO: This code should undergo major refactoring + +import dace +from dace.memlet import Memlet, EmptyMemlet +from dace import SDFG, SDFGState +from dace.graph.nodes import Tasklet, NestedSDFG + +import numpy as np +from collections import OrderedDict +import re + +try: + import tensorflow as tf +except ImportError: + raise ImportError('Cannot use Tensorflow frontend without Tensorflow, ' + + 'please install: https://www.tensorflow.org/install/') + +from tensorflow.python.framework import tensor_util + + +# http://stackoverflow.com/q/3844948/ +def _checkEqualIvo(lst): + return not lst or lst.count(lst[0]) == len(lst) + + +def _tensortype(tensor: tf.Tensor): + """ Returns a numpy type from a given TF tensor. """ + + # Heuristics to determine op type + if isinstance(tensor, tf.Operation): + if len(tensor.outputs) == 1: + tensor = tensor.outputs[0] + elif len(tensor.inputs) == 1: + tensor = tensor.inputs[0] + elif _checkEqualIvo([inp.dtype for inp in tensor.inputs]): + tensor = tensor.inputs[0] + else: + try: + dtype = tensor.get_attr('T') + if dtype.as_numpy_dtype == object: + raise NotImplementedError( + 'Type %s is not a valid numpy type' % str(dtype)) + return dtype.as_numpy_dtype + except ValueError: + pass + raise TypeError('Ambiguous type for operation %s' % tensor) + + if tensor.dtype.as_numpy_dtype == object: + raise NotImplementedError( + 'Type %s is not a valid numpy type' % str(tensor.dtype)) + + if (tensor.dtype.is_bool): + return np.int32 + + return tensor.dtype.as_numpy_dtype + + +def _tensorshape(tensor: tf.Tensor): + if tensor.shape.dims is None or tensor.shape.dims == []: + return 1 # Scalar + return tensor.shape + + +def _string_builder(string): + """ To match DaCe variable naming conventions, replaces all undesired + characters with "_". + """ + newstring = string + if (string[0].isdigit()): + newstring = "_" + string + out = re.sub('[^a-zA-Z0-9_]', '_', newstring) + return out + + +def _name(tensor_or_op): + if isinstance(tensor_or_op, tf.Operation): + return None + return _string_builder(tensor_or_op.name) + + +_LASTSESSION = 0 + + +class TFSession: + def __init__(self, name: str = 'tfsession', seed: int = None, config=None): + """ Creates a DaCe Tensorflow session. + @param name: (optional) The name of the resulting SDFG. + @param seed: (optional) Fix random seed. 
+ """ + self._internal_session = tf.Session(config=config) + + # Set for bookkeeping of already visited nodes + self.visitedNodes = set() + + # Reinit state only used in training mode + self.reinitState = None + + # Different input dictionaries + self.constDict = dict() + self.varDict = dict() + self.inpDict = dict() + self.reinitDict = dict() + self.initDict = dict() + + self.training = False + self.iterations = 1 + self.seed = seed + self.graph = SDFG(name) + self.kill = False + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + pass + + def train(self, optimizer, initializer, iterations, feed_dict, nodes=None): + """ Trains a subgraph for the specified number of iterations and + returns requested nodes after training. + + @param optimizer: A TensorFlow tf.Optimizer node. + @param initializer: Either a list of global and local initializers + or one initializer. + @param iterations: Number of training steps. + @param feed_dict: Dictionary representing input values and arrays + to feed in to the evaluator. + @param nodes: (optional) A TensorFlow node or an iterable + (e.g. list) of nodes to evaluate. + @return: A 2-tuple of (varDict, values) - the first is a dictionary + of all variables used in the network in arbitrary order, + and the second is a tuple of values in the same order as + `nodes`. + """ + + # Initialize a new SDFG + self.graph = SDFG(self.graph.name) + self.graph.propagate = False + self.state = SDFGState("s0", self.graph) + self.graph.add_node(self.state) + self.iterations = iterations + state = self.state + sdfg = self.graph + outputs = [] + output_names = [] + #init state + s0 = state + #computational state" + s1 = sdfg.add_state('s1') + #emtpy exit state + s2 = sdfg.add_state('s2') + # As currently output arrays of conflict resolution do not automaticly + # get reinitialized in each state iterations, we have to manually do + # it in this state. 
+ reinitState = sdfg.add_state("reinitialization") + self.reinitState = reinitState + #set training mode + + self.training = True + + #add edges between states + sdfg.add_edge( + s0, + s1, + dace.graph.edges.InterstateEdge(assignments=dict(__dacet1=0))) + sdfg.add_edge( + s1, reinitState, + dace.graph.edges.InterstateEdge( + condition=dace.properties.CodeProperty.from_string( + "__dacet1 <" + str(iterations - 1), + dace.types.Language.Python), + assignments={'__dacet1': '__dacet1+1'})) + sdfg.add_edge(reinitState, s1, dace.graph.edges.InterstateEdge()) + sdfg.add_edge( + s1, + s2, + dace.graph.edges.InterstateEdge( + condition=dace.properties.CodeProperty.from_string( + "__dacet1 >= " + + str(iterations - 1), dace.types.Language.Python))) + + try: + iter(initializer) + initializer = list(initializer) + except TypeError: + initializer = [initializer] + + try: + iter(nodes) + nodes = list(nodes) + except TypeError: + nodes = [nodes] + + if (not nodes == None): + try: + iter(optimizer) + optimizer = list(optimizer) + except TypeError: + optimizer = [optimizer] + + ########################### + # Prepare subgraph to process + # If only one node was given, construct a list from it + if (not nodes == [None]): + ops = [ + node if isinstance(node, tf.Operation) else node.op + for node in nodes + ] + output_names = [ + _string_builder(node.name) + if not isinstance(node, tf.Operation) else None + for node in nodes + ] + + # Visit initializer and create subgraph for init state + # If only one node was given, construct a list from it + + init = [ + i if isinstance(i, tf.Operation) else i.op for i in initializer + ] + self.visit_backwards(init) + + # Visit the rest of the nodes + self.state = s1 + state = s1 + # As we are in a new state, all variable nodes should be revisited + self.visitedNodes.clear() + self.visit_backwards(optimizer) + if (not nodes == [None]): + self.visit_backwards(ops) + ############################ + + # Remove orphan nodes and register node types + node_types = {} + for state in self.graph.nodes(): + for node in state.nodes(): + if state.in_degree(node) + state.out_degree(node) == 0: + state.remove_node(node) + if node.label in self.constDict: + del self.constDict[node.label] + elif isinstance(node, dace.graph.nodes.AccessNode): + node_types[node.data] = node.desc(self.graph).dtype.type + ############################ + # Set up arguments + sdfg_args = {} + sdfg_args.update(self.constDict) + sdfg_args.update(self.varDict) + sdfg_args.update(self.inpDict) + sdfg_args.update(self.reinitDict) + sdfg_args.update(self.initDict) + + sdfg_args.update({(k if isinstance(k, str) else + _string_builder(k.name + "_Inp")): v + for k, v in feed_dict.items()}) + + # Set scalar arguments to appropriate arrays of size 1 + sdfg_args.update({ + k: (v if isinstance(v, np.ndarray) else np.array( + v, dtype=node_types[k])) + for k, v in sdfg_args.items() + }) + + ############################ + # Create output numpy arrays + if (not nodes == [None]): + outputs = { + name: np.zeros(_tensorshape(node), dtype=_tensortype(node)) + for node, name in zip(nodes, output_names) + if name is not None and name not in sdfg_args + } + outputs.update( + {k: v + for k, v in sdfg_args.items() if k in output_names}) + + sdfg_args.update(outputs) + + ############################ + # Mark outputs as non-transients + for output in outputs: + self.graph.arrays[output].transient = False + ############################ + + # Compile and call the SDFG + self.graph.draw_to_file() + compiled_sdfg = 
self.graph.compile(optimizer=False) + compiled_sdfg(**sdfg_args) + ############################ + + # Return the outputs and weights + + return self.varDict, tuple( + outputs[output] if output is not None else None + for output in output_names) + + def compile(self, nodes, name=None): + """ Compiles a subgraph into a callable function, which is equivalent + to calling `run()`. + @param nodes: Node or an iterable (e.g. list) of nodes to evaluate. + @param name: Name of the SDFG to create, or None for a unique name. + @return: A function that receives a feed_dict, evaluates the nodes, + and returns a tuple of values in the same order as nodes. + """ + # Create a unique name for this session + if name is None: + global _LASTSESSION + _LASTSESSION += 1 + name = "tfsession%d" % _LASTSESSION + + # Initialize a new SDFG + self.graph = SDFG(name) + self.graph.propagate = False + self.state = SDFGState("s0", self.graph) + self.graph.add_node(self.state) + self.visitedNodes.clear() + ############################ + + # Prepare subgraph to process + total_nodes = [] + + # Determine output type + output_type = None + if not isinstance(nodes, + (list, tuple, dict)): # iter() works in TensorFlow + output_type = object + total_nodes.append(nodes) + output_names = _name(nodes) + elif isinstance(nodes, dict): + output_type = type(nodes) + output_names = {} + for k, node in nodes.items(): + try: + iter(node) + if isinstance(node, dict): + raise TypeError( + 'Dictionaries of dictionaries unsupported') + total_nodes.extend(node) + output_names[k] = type(node)(_name(n) for n in node) + except TypeError: + total_nodes.append(node) + output_names[k] = _name(node) + elif isinstance(nodes, (list, tuple)): + output_type = type(nodes) + total_nodes.extend(nodes) + output_names = output_type(_name(node) for node in nodes) + else: + raise TypeError('Unsupported type for fetches: ' + + str(type(nodes))) + + ops = [ + node if isinstance(node, tf.Operation) else node.op + for node in total_nodes + ] + total_output_names = [ + _string_builder(node.name) + if not isinstance(node, tf.Operation) else None + for node in total_nodes + ] + + self.kill = False + self.visit_backwards(ops) + if self.kill: + raise NotImplementedError('Nodes listed above are not implemented') + ############################ + + # Remove orphan nodes and register node types + node_types = {} + for state in self.graph.nodes(): + for node in state.nodes(): + if state.in_degree(node) + state.out_degree(node) == 0: + state.remove_node(node) + if node.label in self.constDict: + del self.constDict[node.label] + elif isinstance(node, dace.graph.nodes.AccessNode): + node_types[node.data] = node.desc(self.graph).dtype.type + ############################ + + # Set up arguments + sdfg_args = {} + sdfg_args.update(self.constDict) + sdfg_args.update(self.varDict) + sdfg_args.update(self.inpDict) + sdfg_args.update(self.initDict) + + # Set scalar arguments to appropriate arrays of size 1 + sdfg_args.update({ + k: (v if isinstance(v, np.ndarray) else np.array( + v, dtype=node_types[k])) + for k, v in sdfg_args.items() + }) + ############################ + + # Create output numpy arrays + outputs = { + name: np.zeros(_tensorshape(node), dtype=_tensortype(node)) + for node, name in zip(total_nodes, total_output_names) + if name is not None and name not in sdfg_args + } + outputs.update( + {k: v + for k, v in sdfg_args.items() if k in total_output_names}) + + sdfg_args.update(outputs) + + ############################ + # Mark outputs as non-transients + for output in 
outputs: + self.graph.arrays[output].transient = False + ############################ + + # Compile the SDFG + self.graph.fill_scope_connectors() + self.graph.draw_to_file() + compiled_sdfg = self.graph.compile(optimizer=False) + + ############################ + # Create the function that invokes the SDFG + def call_func(feed_dict={}): + invoke_args = dict( + sdfg_args, **{(k if isinstance(k, str) else + _string_builder(k.name)): v + for k, v in feed_dict.items()}) + + compiled_sdfg(**invoke_args) + + # Single output + if output_type is object: + return outputs[ + output_names] if output_names is not None else None + # Dictionary of lists/single outputs + elif output_type is dict: + out_dict = {} + for k, v in output_names.items(): + if isinstance(v, (list, tuple)): + out_dict[k] = type(v)( + outputs[vname] if vname is not None else None + for vname in v) + else: + out_dict[k] = outputs[v] if v is not None else None + return out_dict + # List of outputs + else: + return output_type( + outputs[output] if output is not None else None + for output in output_names) + + # Return the function + return call_func + + def run(self, nodes, feed_dict={}, name=None): + """ Evaluates a subgraph and returns a tuple of the evaluated nodes + (behaves similarly to sess.run). + @param nodes: Node or an iterable (e.g. list) of nodes to evaluate. + @param feed_dict: Dictionary representing input values and arrays + to feed in to the evaluator. + @param name: Name of the SDFG to create, or None for a unique name. + + @return: Tuple or dictionary of values in the same order as `nodes`. + """ + callfunc = self.compile(nodes, name=name) + return callfunc(feed_dict=feed_dict) + + def dfs_nodes(self, source): + """ Produce nodes in a depth-first-search (DFS) on a TensorFlow graph. + @param source: The source node to start from. + @return: A generator of nodes in the depth-first-search. + @note: Based on http://www.ics.uci.edu/~eppstein/PADS/DFS.py + by D. Eppstein, July 2004. + """ + + # If source is a list of nodes (or any iterable), start from all + try: + iter(source) + nodes = list(source) + except TypeError: + nodes = [source] + + visited = set() + + for start in nodes: + if start in visited: + continue + visited.add(start) + yield start + + inputSet = [inp.op for inp in start.inputs] + inputSet.extend(list(start.control_inputs)) + stack = [(start, iter(inputSet))] + while stack: + parent, children = stack[-1] + try: + child = next(children) + + if child not in visited: + yield child + visited.add(child) + + inputSet = [inp.op for inp in child.inputs] + inputSet.extend(list(child.control_inputs)) + stack.append((child, iter(inputSet))) + except StopIteration: + stack.pop() + + def visit_backwards(self, node): + """ Visit a graph from an output node backwards to the inputs. """ + for node in self.dfs_nodes(node): + if node not in self.visitedNodes: + self.visit(node) + + def visit(self, node): + """ Visit a specific node in the graph, creating the SDFG. """ + try: + func = getattr(self, "visit_" + node.type) + except AttributeError: + # Only stop processing after all node types have been visited, + # so that we know which implementations are missing. 
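# Dispatch sketch: operator visitors are resolved by TF node type, e.g. a
# node with node.type == "MatMul" is routed to visit_MatMul via
#   func = getattr(self, "visit_MatMul")
# Unknown node types fall through to this branch, which records the missing
# implementation and keeps visiting instead of aborting immediately.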
+ self.kill = True + print('MISSING IMPLEMENTATION:', node.type) + if self.kill == False: + func(node) + #mark node as visited + self.visitedNodes.add(node) + + ###################################################################### + # Operator (TensorFlow graph node) visitors + + def visit_Add(self, node): + self.visit_element_wise_op(node, "+") + + def visit_Mul(self, node): + self.visit_element_wise_op(node, "*") + + def visit_Sub(self, node): + self.visit_element_wise_op(node, "-") + + def visit_RealDiv(self, node): + self.visit_element_wise_op(node, "/") + + def visit_Equal(self, node): + self.visit_element_wise_op(node, "==") + + def visit_Const(self, node): + state = self.state + label = _string_builder(node.name + "_0") + + # Create DaCe shape + shape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(node.outputs[0]))) + # Create np array from tensor value + npArray = tensor_util.MakeNdarray( + node.get_attr('value')).reshape(shape) + + # Add to constDict so that it can be fed to the program + self.constDict[label] = npArray.astype(_tensortype(node)) + + nodeArray = list( + filter(lambda a: a.label == label, self.state.nodes())) + + # If node already present set it non transient, otherwise add node + if (not nodeArray): + dtype = dace.typeclass(_tensortype(node)) + state.add_array(label, shape, dtype, toplevel=True) + else: + nodeArray[0].desc(self.graph).transient = False + + def visit_NoOp(self, node): + # no op case where nothing happens + pass + + def visit_Pack(self, node): + # we do nothing with this op + pass + + def visit_StridedSlice(self, node): + # we do nothing with this op + pass + + def visit_VariableV2(self, node): + + state = self.state + label = _string_builder(node.name) + "_0" + shape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(node.outputs[0]))) + + try: + outputNode = state.find_node(label) + outputNode.desc(self.graph).transient = False + except (LookupError): + dtype = dace.typeclass(_tensortype(node)) + state.add_array(label, shape, dtype) + + # If not already added to the varDict, add a placeholder + # zero-initialized array to it so a value error is not triggered. 
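# Illustrative example (names and shapes are hypothetical): a VariableV2
# called "dense/kernel" with shape (784, 10) becomes the SDFG array
# "dense_kernel_0", and varDict receives np.zeros((784, 10), dtype) as a
# stand-in value until an initializer overwrites it.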
+ if (label not in self.varDict.keys()): + npArray = np.zeros(shape=shape) + self.varDict[label] = npArray.astype(_tensortype(node)) + + def visit_Assign(self, node): + # Simple memcopy from input1 to input0 as Assign has no output list but + # input0 is the variable we want to assign + state = self.state + inputList = [] + inputNodes = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + memlet = Memlet.simple(inputNodes[0], ",".join(inputDims[0])) + state.add_edge(inputNodes[1], None, inputNodes[0], None, memlet) + + def visit_Placeholder(self, node): + + outputShape = [] + outputParams = [] + outputDims = [] + inputShape = [] + inputParams = [] + inputDims = [] + outputTensor = node.outputs[0] + state = self.state + label = _string_builder(node.name + "_0") + + # Check if the node is already in the graph and get as a list + try: + outputNode = state.find_node(label) + + except (LookupError): + outputNode = self.create_and_add_output_node(node) + + dtype = _tensortype(node) + + # If we are in training mode, we set up another map to reduce the huge + # (iterations x batchsize x size of input) input to one dimension less + if (self.training): + # Output dimensions of the map + + outputDims = self.get_default_dims(outputTensor) + outputParams = self.get_default_params(outputTensor, 1) + outputShape = list(map(str, _tensorshape(outputTensor))) + + # Prepend the iterations dimension to the input (t1=iterations) + inputShape.append(str(self.iterations)) + inputShape.extend(outputShape) + inputParams.append("i0") + inputParams.extend(outputParams) + inputDims.append("__dacet1:__dacet1+1") + inputDims.extend(outputDims) + + #create node for the training examples + shape = dace.properties.ShapeProperty.from_string( + ",".join(inputShape)) + dtype = _tensortype(node) + inputNode = state.add_array( + name=label + "_Inp", shape=shape, dtype=dace.typeclass(dtype)) + + #create and add map + mapDict = dict(zip(inputParams, inputDims)) + inMemletDict = dict( + j0=Memlet.simple(inputNode, ",".join(inputParams))) + outMemletDict = dict( + out=Memlet.simple(outputNode, ",".join(outputParams))) + code = "out = j0" + tasklet, map_entry, map_exit = state.add_mapped_tasklet( + label, mapDict, inMemletDict, code, outMemletDict) + state.add_edge(inputNode, None, map_entry, None, + Memlet.simple(inputNode, ",".join(inputDims))) + state.add_edge(map_exit, None, outputNode, None, + Memlet.simple(outputNode, ",".join(outputDims))) + + # If training example node is not already in inputDict, add a + # zero array. This prevents DaCe from raising a key error when + # trying to call the dace function if we only execute a subgraph + # where it does not appear. This might not be necessary any longer. + if (label + "_Inp" not in self.inpDict.keys()): + self.inpDict[label + "_Inp"] = np.zeros( + tuple(map(int, (inputShape))), dtype=dtype) + + # If we are not training, set the output non-transient and add to + # input dict + else: + outputNode.desc(self.graph).transient = False + self.inpDict[label] = np.zeros( + tuple(map(int, (outputNode.desc(self.graph).shape))), + dtype=dtype) + + def visit_TruncatedNormal(self, node): + # Creates a truncated normal array and adds it to initDict + state = self.state + label = _string_builder(node.name + "_0") + # Check if already in graph, set non-transient.
Otherwise add to graph. + try: + outputNode = state.find_node(label) + outputNode.desc(self.graph).transient = False + + except (LookupError): + self.create_and_add_output_node(node) + + seed = 0 if self.seed is None else self.seed + + array = tf.truncated_normal( + node.outputs[0].shape, + seed=seed).eval(session=self._internal_session) + self.initDict[label] = array.astype(_tensortype(node)) + + def visit_RandomStandardNormal(self, node): + + state = self.state + label = _string_builder(node.name + "_0") + + try: + outputNode = state.find_node(label) + outputNode.desc(self.graph).transient = False + + except (LookupError): + self.create_and_add_output_node(node) + + array = tf.random_normal( + node.outputs[0].shape, + seed=self.seed).eval(session=self._internal_session) + self.initDict[label] = array.astype(_tensortype(node)) + + def visit_RandomUniform(self, node): + # Creates a random uniform array and adds it to initDict + state = self.state + label = _string_builder(node.name + "_0") + # Check if already in graph, set non-transient. Otherwise add to graph. + try: + outputNode = state.find_node(label) + outputNode.desc(self.graph).transient = False + + except (LookupError): + self.create_and_add_output_node(node) + + seed = 0 if self.seed is None else self.seed + + array = tf.random_uniform( + node.outputs[0].shape, + seed=seed).eval(session=self._internal_session) + self.initDict[label] = array.astype(_tensortype(node)) + + def visit_RandomUniformInt(self, node): + # Creates a random uniform array and adds it to initDict + state = self.state + label = _string_builder(node.name + "_0") + # Check if already in graph, set non-transient. Otherwise add to graph. + try: + outputNode = state.find_node(label) + outputNode.desc(self.graph).transient = False + + except (LookupError): + self.create_and_add_output_node(node) + + seed = 0 if self.seed is None else self.seed + + array = tf.random_uniform( + node.outputs[0].shape, + dtype=tf.as_dtype(_tensortype(node)), + minval=node.inputs[1], + maxval=node.inputs[2], + seed=seed).eval(session=self._internal_session) + self.initDict[label] = array.astype(_tensortype(node)) + + def visit_Fill(self, node): + # Fills an array with a scalar input value + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + mapParams = [] + mapRange = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + # Scalar input is at position 1 + if (count == 1): + inp, params, dims = self.create_and_add_input_node(inp) + inputList.append(inp.desc(self.graph)) + inputNodes.append(inp) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + + for out in node.outputs: + params = self.get_default_params(out, 1) + dims = self.get_default_dims(out) + outputParams.append(params) + outputDims.append(dims) + + mapLabel = _string_builder(node.type) + mapParams = inputParams[0] + outputParams[0] + mapRange = inputDims[0] + outputDims[0] + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, "out = j0") + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_Mean(self, node): + + inputList = [] + inputNodes = [] + outputList = [] + state = self.state + mapParams = [] + mapRange = [] + outputParams = [] + outputDims = [] + inputParams = [] + 
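# Mean sketch: the map below divides each element by the total input size n
# and accumulates with a sum WCR into a one-element output, i.e. for a
# (2, 3) input A (shape illustrative):
#   out[0] = sum(A[i, j] / 6 for all i, j)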
inputDims = [] + + for count, inp in enumerate(node.inputs): + if (count == 0): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + # Need to get total size of array + n = 1 + for i in (inputNode.desc(self.graph).shape): + n *= i + # Output is scalar + outputParams = ["i1"] + outputDims = ["0:1"] + outputShape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(node.outputs[0]))) + outputNode = state.add_transient( + _string_builder(node.outputs[0].name), + outputShape, + dace.typeclass(_tensortype(inp)), + toplevel=True) + outputList = [] + outputList.append(outputNode) + + mapLabel = _string_builder(node.type) + mapParams = inputParams[0] + outputParams + mapRange = inputDims[0] + outputDims + + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + self.reinitCR(outputList[0], [["i0"]], [["0:1"]], "0") + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, + "out = j0/" + str(n)) + self.add_out_memlets(outputList, mapExit, tasklet, [outputDims], + [outputParams], "lambda a, b: (a + b)", 0) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_Tile(self, node): + # Replicates input multiple times + inputList = [] + inputNodes = [] + + state = self.state + + for inp in node.inputs: + + label = _string_builder(inp.name) + try: + inputNode = state.find_node(label) + except (LookupError): + + inputNode = self.create_and_add_input_node(inp)[0] + + inputNodes.append(inputNode) + inputList.append(inputNode.desc(self.graph)) + + outputList = self.create_and_add_output_node(node) + + mapLabel = _string_builder(node.type) + outputDims = self.get_default_dims(node.outputs[0]) + outputParams = self.get_default_params(node.outputs[0]) + inputDims = self.get_default_dims(node.inputs[0]) + inputParams = [] + + for i, dim in enumerate(inputList[0].shape): + inputParams.append("i" + str(i) + "%" + str(dim)) + + mapDict = dict(zip(outputParams, outputDims)) + inMemletDict = dict( + j0=Memlet.simple(inputNodes[0], ",".join(inputParams))) + outMemletDict = dict( + out=Memlet.simple(outputList[0], ",".join(outputParams))) + code = "out = j0" + tasklet, map_entry, map_exit = state.add_mapped_tasklet( + mapLabel, mapDict, inMemletDict, code, outMemletDict) + state.add_edge(inputNodes[0], None, map_entry, None, + Memlet.simple(inputNodes[0], ",".join(inputDims))) + state.add_edge(map_exit, None, outputList[0], None, + Memlet.simple(outputList[0], ",".join(outputDims))) + + def visit_PreventGradient(self, node): + # Just a memcopy, works like visit_assign or visit_identity + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + #relevant input is at position 0 + if (count == 0): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + + for count, out in enumerate(node.outputs): + + dims = self.get_default_dims(out) + params = self.get_default_params(out) + outputParams.append(params) + outputDims.append(dims) + + memlet = Memlet.simple(inputNodes[0], ",".join(inputDims[0])) + state.add_edge(inputNodes[0], None, outputList[0], None, memlet) + + def 
visit_ExpandDims(self, node): + # Takes an N-dimensional array and adds one dimension to it with a + # length of 1. Example: (M,K) -> (1,M,K). + # We can just use DaCe memory copy to do the same + state = self.state + inputList = [] + inputNodes = [] + inputDims = [] + inputParams = [] + + for count, inp in enumerate(node.inputs): + if (count == 0): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputDims.append(dims) + inputParams.append(params) + + outputList = self.create_and_add_output_node(node) + memlet = Memlet.simple(inputNodes[0], ",".join(inputDims[0])) + state.add_edge(inputNodes[0], None, outputList[0], None, memlet) + + def visit_ApplyGradientDescent(self, node): + + state = self.state + inputList = [] + inputNodes = [] + mapParams = [] + mapRange = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + + inputNode, params, dims = self.create_and_add_input_node(inp) + inputParams.append(params) + inputDims.append(dims) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + mapLabel = _string_builder(node.type) + #inputList[1] is learning rate which needs its own parameter + inputParams[1] = ["i4"] + # This is the variable which is input and output of this map at the same + # time. We create the output version of it here + out = node.inputs[0] + shape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(out))) + outName = _string_builder(out.name) + dtype = _tensortype(out) + outputNode = state.add_array(outName, shape, dtype) + dims = self.get_default_dims(out) + params = self.get_default_params(out) + outputList = [outputNode] + outputParams = [params] + outputDims = [dims] + + mapLabel = _string_builder(node.type) + mapParams = inputParams[0] + ["i4"] + mapRange = inputDims[0] + ["0:1"] + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0', 'j1', 'j2'}, {'out'}, + "out = j0-(j1*j2)") + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams) + + def visit_MatMul(self, node): + # 2d Matrix Multiplication + inputList = [] + inputNodes = [] + state = self.state + mapParams = [] + outputParams = [[]] + mapRange = [] + outputDims = [[]] + inputParams = [[], []] + inputDims = [[], []] + + for inp in node.inputs: + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + outputList = self.create_and_add_output_node(node) + + ndims = len(outputList[0].desc(self.graph).shape) + # Params for higher dimensions (not verified) + # (for 2d it works) + for i in range(0, ndims + 1): + if (i == ndims): + mapParams.append("i" + str(i)) + inputParams[1].append("i" + str(i)) + outputParams[0].append("i" + str(i)) + + elif (i == ndims - 1): + mapParams.append("i" + str(i)) + inputParams[0].append("i" + str(i)) + inputParams[1].append("i" + str(i)) + + elif (i == ndims - 2): + mapParams.append("i" + str(i)) + inputParams[0].append("i" + str(i)) + outputParams[0].append("i" + str(i)) + + else: + mapParams.append("i" + str(i)) + inputParams[0].append("i" + str(i)) + inputParams[1].append("i" + str(i)) + outputParams[0].append("i" + str(i)) + + for i in range(0, ndims): + inputDims[0].append(str(0) + ":" + str(node.inputs[0].shape[i])) + inputDims[1].append(str(0) + ":" + 
str(node.inputs[1].shape[i])) + outputDims[0].append(str(0) + ":" + str(node.outputs[0].shape[i])) + mapRange.append(str(0) + ":" + str(node.inputs[0].shape[i])) + + mapRange.append(str(0) + ":" + str(node.outputs[0].shape[ndims - 1])) + #if first input needs to be transposed + if (node.get_attr("transpose_a")): + mapRange[0], mapRange[1] = mapRange[1], mapRange[0] + inputParams[0][0], inputParams[0][1] = inputParams[0][ + 1], inputParams[0][0] + #if second input needs to be transposed + if (node.get_attr("transpose_b")): + inputParams[1][0], inputParams[1][1] = inputParams[1][ + 1], inputParams[1][0] + + mentry, mexit = state.add_map('matmul_outer', + {mapParams[1]: mapRange[1]}, + dace.ScheduleType.Sequential) + minentry, minexit = state.add_map('matmul_inner', { + mapParams[0]: mapRange[0], + mapParams[2]: mapRange[2] + }, dace.ScheduleType.CPU_Multicore) + tasklet = state.add_tasklet('mm_code', {'j0', 'j1'}, {'out'}, + 'out = j0*j1') + + for i, inp in enumerate(inputNodes): + name = "j" + str(i) + memlet = Memlet.simple(inp, ",".join(inputParams[i])) + state.add_edge(minentry, None, tasklet, name, memlet) + + for i, out in enumerate(outputList): + name = "out" + memlet = Memlet.simple( + out, + ",".join(outputParams[i]), + wcr_str='lambda a,b: a+b', + wcr_identity=0) + state.add_edge(tasklet, name, minexit, None, memlet) + + self.reinitCR(outputList[0], outputParams, outputDims, '0') + self.add_out_memlets(outputList, mexit, minexit, outputDims, + outputParams, 'lambda a,b: a+b', 0) + self.add_in_memlets(inputNodes, mentry, minentry, inputDims, + inputParams) + + def visit_element_wise_op(self, node, operation): + """ Handles all the element wise operations, supports broadcasting. """ + inputList = [] + inputNodes = [] + mapParams = [] + outputParams = [] + mapRange = [] + outputDims = [] + inputParams = [] + inputDims = [] + state = self.state + + for inp in node.inputs: + + inputNode, _, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputDims.append(dims) + + outputNodes = self.create_and_add_output_node(node) + mapLabel = _string_builder(node.type) + #create params + for inp in inputList: + inputParamsString = [] + for i, dim in enumerate(inp.shape): + #scalar case that we want to broadcast + if (str(dim) == "1"): + inputParamsString.append("0") + else: + inputParamsString.append("i" + str(i)) + + inputParams.append(inputParamsString) + + params = self.get_default_params(node.outputs[0]) + dims = self.get_default_dims(node.outputs[0]) + outputParams.append(params) + outputDims.append(dims) + + mapParams = outputParams[0] + mapRange = outputDims[0] + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0', 'j1'}, {'out'}, + "out = j0 " + operation + " j1") + self.add_out_memlets(outputNodes, mapExit, tasklet, outputDims, + outputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_Conv2D(self, node): + + inputList = [] + inputNodes = [] + ndims = 0 + strides = node.get_attr("strides")[1] + state = self.state + + for inp in node.inputs: + + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + outputList = self.create_and_add_output_node(node) + ndims = len(outputList[0].desc(self.graph).shape) + mapLabel = _string_builder(node.type) + + mapParams = [] + outputParams = [] + mapRange = [] + outputDims = [[]] + 
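# Conv2D sketch: with NHWC inputs and HWCF filters, the two nested maps
# built below compute, for stride s (indices as in inputParams/outputParams):
#   out[i0, i1, i2, i4] += in[i0, i1*s + i5, i2*s + i6, i3] * w[i5, i6, i3, i4]
# accumulated through a sum WCR over the filter window (i5, i6) and the
# input channels i3.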
inputParams = [] + inputDims = [[], []] + #create conv params + inputParams.append([ + "i0", "i1*" + str(strides) + "+i5", "i2*" + str(strides) + "+i6", + "i3" + ]) + inputParams.append(["i5", "i6", "i3", "i4"]) + outputParams.append(["i0", "i1", "i2", "i4"]) + #create conv dims + for i in range(0, ndims): + inputDims[0].append(str(0) + ":" + str(node.inputs[0].shape[i])) + inputDims[1].append(str(0) + ":" + str(node.inputs[1].shape[i])) + outputDims[0].append(str(0) + ":" + str(node.outputs[0].shape[i])) + # add a padding map for same padding (zero padding so that input and + # output of convolution have the same size) + if (str(node.get_attr("padding"))[2:-1] == "SAME"): + paddedInput, paddedDims = self.inputPadding( + node, inputNodes[0], inputList[0], outputList[0].desc( + self.graph).shape[1], inputList[1].shape[0], strides, + inputDims[0]) + inputDims[0] = paddedDims + inputList[0] = paddedInput + + mapParams = outputParams[0] + mapParams2 = inputParams[1][:-1] + mapRange = outputDims[0] + mapRange2 = inputDims[1][:-1] + + mapEntry, mapExit = state.add_map(mapLabel + "_outer", + dict(zip(mapParams, mapRange))) + mapEntry2, mapExit2 = state.add_map(mapLabel + "_inner", + dict(zip(mapParams2, mapRange2))) + self.reinitCR(outputList[0], outputParams, outputDims, "0") + tasklet = state.add_tasklet(mapLabel, {'j0', 'j1'}, {'out'}, + "out = j0 * j1") + self.add_out_memlets(outputList, mapExit, mapExit2, outputDims, + outputParams, 'lambda a,b: a+b', 0) + self.add_in_memlets(inputNodes, mapEntry, mapEntry2, inputDims, + inputParams) + #add memlets from inner map to tasklet + for i, inp in enumerate(inputNodes): + name = "j" + str(i) + memlet = Memlet.simple(inp, ",".join(inputParams[i])) + state.add_edge(mapEntry2, None, tasklet, name, memlet) + #add memlets from tasklet to cr + for i, out in enumerate(outputList): + name = "out" + memlet = Memlet.simple( + out, + ",".join(outputParams[i]), + wcr_str='lambda a,b: a+b', + wcr_identity=0) + state.add_edge(tasklet, name, mapExit2, None, memlet) + + def visit_BiasAdd(self, node): + + inputList = [] + inputNodes = [] + state = self.state + + for inp in node.inputs: + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + outputList = self.create_and_add_output_node(node) + dims = outputList[0].desc(self.graph).shape + + mapLabel = _string_builder(node.type) + mapParams = [] + outputParams = [] + mapRange = [] + outputDims = [] + inputParams = [[], []] + inputDims = [[], []] + + params = self.get_default_params(node.outputs[0]) + dims = self.get_default_dims(node.outputs[0]) + outputParams.append(params) + outputDims.append(dims) + + mapParams = outputParams[0] + inputParams[0] = outputParams[0] + #the bias matches the last dimension of input resp.
output + inputParams[1] = [mapParams[-1]] + mapRange = outputDims[0] + inputDims[0] = outputDims[0] + inputDims[1] = ["0:" + str(node.inputs[1].shape[0])] + + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0', 'j1'}, {'out'}, + "out = j0 + j1") + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_MaxPool(self, node): + + inputList = [] + inputNodes = [] + dims = [] + inputDims = [] + strides = node.get_attr("strides")[1] + ksize = node.get_attr("ksize")[1] + state = self.state + + for inp in node.inputs: + inputNode, _, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputDims.append(dims) + inputParams = [[ + "i0", "i1*" + str(strides) + "+i4", "i2*" + str(strides) + "+i5", + "i3" + ]] + + outputParams = [] + outputDims = [] + outputList = self.create_and_add_output_node(node) + dims = self.get_default_dims(node.outputs[0]) + params = self.get_default_params(node.outputs[0]) + outputDims.append(dims) + outputParams.append(params) + + mapLabel = _string_builder(node.type) + mapParams1 = outputParams[0] + mapRange1 = outputDims[0] + mapParams2 = ["i4", "i5"] + mapRange2 = ["0:" + str(ksize), "0:" + str(ksize)] + + mapEntry, mapExit = state.add_map(mapLabel + "_outer", + dict(zip(mapParams1, mapRange1))) + mapEntry2, mapExit2 = state.add_map(mapLabel + "_inner", + dict(zip(mapParams2, mapRange2))) + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, "out = j0") + self.reinitCR(outputList[0], outputParams, outputDims, "-9999999999") + self.add_out_memlets(outputList, mapExit, mapExit2, outputDims, + outputParams, 'lambda a,b: max(a,b)', -9999999999) + self.add_in_memlets(inputNodes, mapEntry, mapEntry2, inputDims, + inputParams) + #add memlets from inner map to tasklet + for i, inp in enumerate(inputNodes): + name = "j" + str(i) + memlet = Memlet.simple(inp, ",".join(inputParams[i])) + state.add_edge(mapEntry2, None, tasklet, name, memlet) + #add memlets from tasklet to cr + for i, out in enumerate(outputList): + name = "out" + memlet = Memlet.simple( + out, + ",".join(outputParams[i]), + wcr_str='lambda a,b: max(a,b)', + wcr_identity=-9999999999) + state.add_edge(tasklet, name, mapExit2, None, memlet) + + def visit_Relu(self, node): + + inputList = [] + inputNodes = [] + state = self.state + inputParams = [] + inputDims = [] + + for inp in node.inputs: + + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + + mapLabel = _string_builder(node.type) + mapParams = [] + mapRange = [] + mapParams = inputParams[0] + mapRange = inputDims[0] + + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, + "out = max(dace.float32(0),j0)") + self.add_out_memlets(outputList, mapExit, tasklet, inputDims, + inputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_ShapeN(self, node): + inputList = [] + inputNodes = [] + state = self.state + inputParams = [] + inputDims = [] + + for inp in node.inputs: + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + 
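# ShapeN sketch: for each input tensor, the tasklets below write its static
# shape into the corresponding output, e.g. an input of shape (64, 28, 28, 1)
# (illustrative) yields:
#   out[0] = 64; out[1] = 28; out[2] = 28; out[3] = 1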
inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + + mapLabel = _string_builder(node.type) + for i, node in enumerate(outputList): + tasklet = state.add_tasklet( + mapLabel + str(i), {}, {'out'}, '\n'.join([ + 'out[%d] = %s' % (j, dim) + for j, dim in enumerate(inputList[i].shape) + ])) + self.state.add_edge( + tasklet, 'out', node, None, + Memlet.simple(node, '0:' + str(len(inputDims[i])))) + + def visit_Reshape(self, node): + + state = self.state + inputList = [] + inputNodes = [] + + inp = node.inputs[0] + inputParams = [] + inputDims = [] + inputNode, params, dims = self.create_and_add_input_node(inp) + inputParams.append(params) + inputDims.append(dims) + inDims = max(inp.shape.ndims, 1) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + outputDims = [] + outputList = self.create_and_add_output_node(node) + dims = outputList[0].desc(self.graph).shape + outDims = len(dims) + outputDims.append(self.get_default_dims(node.outputs[0])) + + mapLabel = _string_builder(node.type) + mapParams = [] + outputParams = [[]] + mapRange = [] + mapParams = inputParams[0] + mapRange = inputDims[0] + + # Reshape from 4 to 2 dimensions + if (inDims > outDims): + outputParams[0] = [ + "i0", "i1*" + str(node.inputs[0].shape[2]) + "*" + str( + node.inputs[0].shape[3]) + "+i2*" + str( + node.inputs[0].shape[3]) + "+i3" + ] + # Reshape from 2 to 4 dimensions + elif (inDims < outDims): + outputParams[0] = [ + "i0", "i1/(" + str(node.outputs[0].shape[2]) + "*" + str( + node.outputs[0].shape[3]) + ")", + "(i1%" + "(" + str(node.outputs[0].shape[2]) + "*" + str( + node.outputs[0].shape[3]) + "))/" + str( + node.outputs[0].shape[3]), + "i1%" + str(node.outputs[0].shape[3]) + ] + # If they have the same dimension + else: + outputParams[0] = mapParams + mapRange = outputDims[0] + inputDims[0] = outputDims[0] + + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, "out = j0") + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_MaxPoolGrad(self, node): + # TODO: Currently only supports 2x2 maxpooling + state = self.state + mapParams = [] + mapRange = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + inputList = [] + inputNodes = [] + + for count, inp in enumerate(node.inputs): + + inputNode, _, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + params = [] + + for ndims, dim in enumerate(inp.shape): + if ((not count == 0) and (ndims == 1 or ndims == 2)): + params.append("i" + str(ndims) + "/2") + + else: + params.append("i" + str(ndims)) + + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + mapLabel = _string_builder(node.type) + + dtype = dace.typeclass(_tensortype(node)) + shape = dace.properties.ShapeProperty.from_string( + str(inputList[0].shape)) + + tempNode = state.add_transient( + _string_builder(node.name + "_tmp"), shape, dtype, toplevel=True) + tempList = [tempNode] + + outputDims = inputDims + outputParams = inputParams + # Copy as we manipulate inputParams but don't want map params/range to + # change + mapParams = inputParams[0].copy() + mapRange = inputDims[0].copy() + + mapEntry, mapExit = state.add_map(mapLabel + "_map1", + 
dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet( + mapLabel + "_map1", {'j0', 'j1', 'j2'}, {'out'}, + "if (j0==j1):\n\tout = j2\nelse:\n\tout = 0") + + self.add_out_memlets(tempList, mapExit, tasklet, outputDims, + outputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + # Second map: + # as we don't have the indices of the maxpooling we need to manually + # figure out which one contributed. If it is ambiguous we break the + # tie by the following priority k[i,j]0):\n\tout = j0\nelse:\n\tout = 0") + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_BiasAddGrad(self, node): + + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + mapParams = [] + mapRange = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + for out in node.outputs: + outputParams.append([inputParams[0][-1]]) + outputDims.append([inputDims[0][-1]]) + + mapLabel = _string_builder(node.type) + mapParams = inputParams[0] + mapRange = inputDims[0] + + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, "out = j0") + self.reinitCR(outputList[0], outputParams, outputDims, "0") + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams, 'lambda a,b: a+b', 0) + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + + def visit_Conv2DBackpropInput(self, node): + + inputList = [] + inputNodes = [] + mapParams = [] + outputParams = [] + mapRange = [] + outputDims = [[]] + inputParams = [] + inputDims = [[], []] + strides = node.get_attr("strides")[1] + state = self.state + + for count, inp in enumerate(node.inputs): + if (not count == 0): + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + outputList = self.create_and_add_output_node(node) + + for i in range(0, 7): + mapParams.append("i" + str(i)) + ndims = len(outputList[0].desc(self.graph).shape) + for i in range(0, ndims): + inputDims[1].append(str(0) + ":" + str(inputList[1].shape[i])) + inputDims[0].append(str(0) + ":" + str(inputList[0].shape[i])) + outputDims[0].append( + str(0) + ":" + str(outputList[0].desc(self.graph).shape[i])) + + ksize = inputList[0].shape[0] + paddedInput, paddedDims = self.inputPadding( + node, inputNodes[1], inputList[1], outputList[0].desc( + self.graph).shape[1], ksize, strides, inputDims[1]) + inputDims[1] = paddedDims + inputList[1] = paddedInput + inputParams.append( + ["-1-i5+" + str(ksize), "-1-i6+" + str(ksize), "i3", "i4"]) + inputParams.append([ + "i0", "i1*" + str(strides) + "+i5", "i2*" + str(strides) + "+i6", + "i4" + ]) + + outputParams.append(["i0", "i1", "i2", "i3"]) + + mapLabel = _string_builder(node.type) + mapParams = ["i0", "i1", "i2", "i3"] + mapParams2 = ["i5", "i6", "i4"] + mapRange = outputDims[0] + mapRange2 = inputDims[0][:-2] + mapRange2.append(inputDims[1][-1]) + mapEntry, mapExit = state.add_map(mapLabel + "_outer", + dict(zip(mapParams, mapRange))) + mapEntry2, mapExit2 = state.add_map(mapLabel + 
"_inner", + dict(zip(mapParams2, mapRange2))) + + tasklet = state.add_tasklet(mapLabel, {'j0', 'j1'}, {'out'}, + "out = j0 * j1") + self.reinitCR(outputList[0], outputParams, outputDims, "0") + + self.add_out_memlets(outputList, mapExit, mapExit2, outputDims, + outputParams, 'lambda a,b: a+b', 0) + self.add_in_memlets(inputNodes, mapEntry, mapEntry2, inputDims, + inputParams) + for i, inp in enumerate(inputNodes): + name = "j" + str(i) + memlet = Memlet.simple(inp, ",".join(inputParams[i])) + state.add_edge(mapEntry2, None, tasklet, name, memlet) + for i, out in enumerate(outputList): + name = "out" + memlet = Memlet.simple( + out, + ",".join(outputParams[i]), + wcr_str='lambda a,b: a+b', + wcr_identity=0) + state.add_edge(tasklet, name, mapExit2, None, memlet) + + def visit_Conv2DBackpropFilter(self, node): + + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + if (count != 1): + inputNode, _, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputDims.append(dims) + inputParams.append(["i0", "i1+i5", "i2+i6", "i3"]) + inputParams.append(["i0", "i1", "i2", "i4"]) + + outputList = self.create_and_add_output_node(node) + for count, out in enumerate(node.outputs): + params = ["i5", "i6", "i3", "i4"] + dims = self.get_default_dims(out) + outputParams.append(params) + outputDims.append(dims) + + mapParams = outputParams[0] + mapParams2 = inputParams[1][:-1] + mapRange = outputDims[0] + mapRange2 = inputDims[1][:-1] + mapLabel = _string_builder(node.type) + mapEntry, mapExit = state.add_map(mapLabel + "_outer", + dict(zip(mapParams, mapRange))) + mapEntry2, mapExit2 = state.add_map(mapLabel + "_inner", + dict(zip(mapParams2, mapRange2))) + + tasklet = state.add_tasklet(mapLabel, {'j0', 'j1'}, {'out'}, + "out = j0*j1") + + self.reinitCR(outputList[0], outputParams, outputDims, "0") + + self.add_out_memlets(outputList, mapExit, mapExit2, outputDims, + outputParams, 'lambda a,b: a+b', 0) + self.add_in_memlets(inputNodes, mapEntry, mapEntry2, inputDims, + inputParams) + + for i, inp in enumerate(inputNodes): + name = "j" + str(i) + memlet = Memlet.simple(inp, ",".join(inputParams[i])) + state.add_edge(mapEntry2, None, tasklet, name, memlet) + + for i, out in enumerate(outputList): + name = "out" + memlet = Memlet.simple( + out, + ",".join(outputParams[i]), + wcr_str='lambda a,b: a+b', + wcr_identity=0) + state.add_edge(tasklet, name, mapExit2, None, memlet) + + def visit_SparseSoftmaxCrossEntropyWithLogits(self, node): + + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + inputParams = [] + inputDims = [] + + for inp in node.inputs: + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputDims.append(dims) + inputParams.append(params) + + for out in node.outputs: + label = _string_builder(out.name) + try: + outputNode = state.find_node(label) + except (LookupError): + dtype = dace.typeclass(_tensortype(node)) + shape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(out))) + outputNode = state.add_transient( + label, shape, dtype, toplevel=True) + outputList.append(outputNode) + + mapLabel = _string_builder(node.type) + mapParams = inputParams[0] + mapRange = inputDims[0] + + #1st map, get maximum in each batchsize dimension + dtype = 
dace.typeclass(_tensortype(node)) + shape = dace.properties.ShapeProperty.from_string( + str(inputList[1].shape)) + + temp1Node = state.add_transient( + mapLabel + "_max_tmp", shape, dtype, toplevel=True) + mapEntry, mapExit = state.add_map(mapLabel + "_max", + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel + "_max", {'j0'}, {'out'}, + "out = j0") + self.reinitCR(temp1Node, [inputParams[1]], [inputDims[1]], + "-999999999999") + self.add_in_memlets([inputNodes[0]], mapEntry, tasklet, [inputDims[0]], + [inputParams[0]]) + self.add_out_memlets([temp1Node], mapExit, tasklet, [inputDims[1]], + [inputParams[1]], 'lambda a,b: max(a,b)', + -9999999999) + + # 2nd map, calculate the denominator sum + temp2Node = state.add_transient( + mapLabel + "_denominator_tmp", shape, dtype, toplevel=True) + mapEntry, mapExit = state.add_map(mapLabel + "_denominator", + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet( + mapLabel + "_denominator", {'j0', 'j1'}, {'out'}, + "out = dace::math::exp(j0-j1);", + language=dace.types.Language.CPP) + self.reinitCR(temp2Node, [inputParams[1]], [inputDims[1]], "0") + inList = [inputNodes[0], temp1Node] + self.add_in_memlets(inList, mapEntry, tasklet, inputDims, inputParams) + self.add_out_memlets([temp2Node], mapExit, tasklet, [inputDims[1]], + [inputParams[1]], 'lambda a,b: a+b', 0) + + # 3rd map, calculate the softmax + shape = dace.properties.ShapeProperty.from_string( + str(inputList[0].shape)) + temp3Node = state.add_transient( + mapLabel + "_softmax_tmp", shape, dtype, toplevel=True) + mapEntry, mapExit = state.add_map(mapLabel + "_softmax", + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet( + mapLabel + "_softmax", {'j0', 'j1', 'j2'}, {'out'}, + "out = (dace::math::exp(j0-j1))/j2;", + language=dace.types.Language.CPP) + inList = [inputNodes[0], temp1Node, temp2Node] + paramsList = inputParams + [inputParams[1]] + dimsList = inputDims + [inputDims[1]] + self.add_in_memlets(inList, mapEntry, tasklet, dimsList, paramsList) + self.add_out_memlets([temp3Node], mapExit, tasklet, [inputDims[0]], + [inputParams[0]]) + + # 4th map, calculate the cross-entropy loss for an optional loss output + mapEntry, mapExit = state.add_map(mapLabel + "_loss", + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet( + mapLabel + "_loss", {'j0', 'j1'}, {'out'}, + "if (int(j1) == i1) {\n\tout=-(dace::math::log(j0));}\nelse{\n\tout=0;}", + language=dace.types.Language.CPP) + self.reinitCR(outputList[0], [inputParams[1]], [inputDims[1]], "0") + self.add_in_memlets([temp3Node, inputNodes[1]], mapEntry, tasklet, + inputDims, inputParams) + self.add_out_memlets([outputList[0]], mapExit, tasklet, [inputDims[1]], + [inputParams[1]], 'lambda a,b: a+b', 0) + + # 5th map, gradient of the whole layer + mapEntry, mapExit = state.add_map(mapLabel + "_gradient", + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet( + mapLabel + "_gradient", {'j0', 'j1'}, {'out'}, + "if(int(j1)==i1):\n\tout = j0-1\nelse:\n\tout = j0") + self.add_out_memlets([outputList[1]], mapExit, tasklet, [inputDims[0]], + [inputParams[0]]) + self.add_in_memlets([temp3Node, inputNodes[1]], mapEntry, tasklet, + inputDims, inputParams) + + def visit_Identity(self, node): + + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + inputParams = [] + inputDims = [] + + # Create input node and its params + for count, inp in enumerate(node.inputs): + if (count == 0): + inputNode, params, dims = self.create_and_add_input_node(inp) + 
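# The SparseSoftmaxCrossEntropyWithLogits visitor above decomposes into the
# numerically stable form (sketch; b = batch index, c = class index):
#   m[b]     = max over c of x[b, c]                  (1st map, WCR max)
#   Z[b]     = sum over c of exp(x[b, c] - m[b])      (2nd map, WCR sum)
#   p[b, c]  = exp(x[b, c] - m[b]) / Z[b]             (3rd map)
#   loss[b]  = -log(p[b, label[b]])                   (4th map, WCR sum)
#   dx[b, c] = p[b, c] - (1 if c == label[b] else 0)  (5th map)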
inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + memlet = Memlet.simple(inputNodes[0], ",".join(inputDims[0])) + state.add_edge(inputNodes[0], None, outputList[0], None, memlet) + + def visit_LRNGrad(self, node): + + inputList = [] + inputNodes = [] + outputList = [] + state = self.state + + alpha = str(node.get_attr("alpha")) + beta = str(node.get_attr("beta")) + bias = str(node.get_attr("bias")) + depth_radius = str(node.get_attr("depth_radius")) + + for count, inp in enumerate(node.inputs): + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + if (count == 0): + shortDims = [] + shortAccesses = [] + for dim in inp.shape: + shortDims.append("0:" + str(dim)) + shortAccesses.append(str(dim)) + longDims = [] + longDims = shortDims + ["0:" + depth_radius + "*2+1"] + paddedDims = [] + paddedDims += shortDims + paddedDims[-1] += "+" + depth_radius + "*2" + + label = _string_builder(node.name) + outputList = self.create_and_add_output_node(node) + longParams = ["i0", "i1", "i2", "i3", "i4"] + shortParams = ["i0", "i1", "i2", "i3"] + copyParams = ["i0", "i1", "i2", "i3+" + depth_radius] + normParams = ["i0", "i1", "i2", "i3+i4"] + + paddedShape = [] + paddedShape += shortAccesses + paddedShape[-1] += "+" + depth_radius + paddedInput = state.add_transient( + label + "_paddedInput", + paddedShape, + dace.typeclass(_tensortype(node)), + toplevel=True) + mapEntry, mapExit = state.add_map(label + "_padding", + dict(zip(shortParams, shortDims))) + tasklet = state.add_tasklet(label + "_padding", {'j0'}, {'out'}, + "out=j0") + self.add_in_memlets([inputNodes[2]], mapEntry, tasklet, [shortDims], + [shortParams]) + self.add_out_memlets([paddedInput], mapExit, tasklet, [paddedDims], + [copyParams]) + + sqrsum = state.add_transient( + label + "_Sqrsum", shortAccesses, _tensortype(node), toplevel=True) + mapEntry, mapExit = state.add_map(label + "_sqrsum", + dict(zip(longParams, longDims))) + tasklet = state.add_tasklet(label + "_sqrsum", {'j0'}, {'out'}, + "out=j0*j0") + self.reinitCR(sqrsum, [shortParams], [shortDims], "0") + self.add_in_memlets([paddedInput], mapEntry, tasklet, [paddedDims], + [normParams]) + self.add_out_memlets([sqrsum], mapExit, tasklet, [shortDims], + [shortParams], 'lambda a,b: a+b', 0) + + label = _string_builder(node.name) + norm = state.add_transient( + label + "_Norm", shortAccesses, _tensortype(node), toplevel=True) + mapEntry, mapExit = state.add_map(label + "_norm", + dict(zip(shortParams, shortDims))) + tasklet = state.add_tasklet(label + "_norm", {'j0'}, {'out'}, + "out=" + alpha + "*j0+" + bias) + self.add_in_memlets([sqrsum], mapEntry, tasklet, [shortDims], + [shortParams]) + self.add_out_memlets([norm], mapExit, tasklet, [shortDims], + [shortParams]) + + preOut = state.add_transient( + label + "_preOut", shortAccesses, _tensortype(node), toplevel=True) + mapEntry, mapExit = state.add_map(label, dict( + zip(longParams, longDims))) + taskletCode = "if (i4==" + depth_radius + "){\n out = pow(j2," + beta + ")-2*" + alpha + "*" + beta + "*j1*j0/j2;}\n else{\n out = -2*" + alpha + "*" + beta + "*j1*j0/j2;}" + tasklet = state.add_tasklet( + label, {'j0', 'j1', 'j2'}, {'out'}, + taskletCode, + language=dace.types.Language.CPP) + self.reinitCR(preOut, [shortParams], [shortDims], "0") + inList = [inputNodes[1]] + inList.append(paddedInput) + inList.append(norm) + 
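# LRNGrad sketch: with norm = alpha * sqrsum + bias computed above, the
# windowed map below accumulates, over the 2*depth_radius+1 neighboring
# channels, the term -2*alpha*beta*j1*j0/j2 from taskletCode (j2 = norm),
# with the center channel (i4 == depth_radius) adding pow(j2, beta); the
# final map then multiplies the result with the incoming gradient.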
self.add_in_memlets(inList, mapEntry, tasklet, + [shortDims, paddedDims, shortDims], + [shortParams, normParams, shortParams]) + self.add_out_memlets([preOut], mapExit, tasklet, [shortDims], + [shortParams], 'lambda a,b: a+b', 0) + + mapEntry, mapExit = state.add_map(label + "_out", + dict(zip(shortParams, shortDims))) + tasklet = state.add_tasklet(label + "_out", {'j0', 'j1'}, {'out'}, + "out=j0*j1") + self.add_in_memlets([inputNodes[0], preOut], mapEntry, tasklet, + [shortDims, shortDims], [shortParams, shortParams]) + self.add_out_memlets(outputList, mapExit, tasklet, [shortDims], + [shortParams]) + + def visit_LRN(self, node): + + inputList = [] + inputNodes = [] + outputList = [] + state = self.state + alpha = str(node.get_attr("alpha")) + beta = str(node.get_attr("beta")) + bias = str(node.get_attr("bias")) + depth_radius = str(node.get_attr("depth_radius")) + + for count, inp in enumerate(node.inputs): + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + if (count == 0): + shortDims = [] + shortAccesses = [] + for dim in inp.shape: + shortDims.append("0:" + str(dim)) + shortAccesses.append(str(dim)) + longDims = [] + longDims = shortDims + ["0:" + depth_radius + "*2+1"] + paddedDims = [] + paddedDims += shortDims + paddedDims[-1] += "+" + depth_radius + "*2" + + label = _string_builder(node.name) + outputList = self.create_and_add_output_node(node) + longParams = ["i0", "i1", "i2", "i3", "i4"] + shortParams = ["i0", "i1", "i2", "i3"] + copyParams = ["i0", "i1", "i2", "i3+" + depth_radius] + normParams = ["i0", "i1", "i2", "i3+i4"] + + paddedShape = [] + paddedShape += shortAccesses + paddedShape[-1] += "+" + depth_radius + paddedInput = state.add_transient( + label + "_paddedInput", + paddedShape, + dace.typeclass(_tensortype(node)), + toplevel=True) + mapEntry, mapExit = state.add_map(label + "_padding", + dict(zip(shortParams, shortDims))) + tasklet = state.add_tasklet(label + "_padding", {'j0'}, {'out'}, + "out=j0") + self.add_in_memlets([inputNodes[0]], mapEntry, tasklet, [shortDims], + [shortParams]) + self.add_out_memlets([paddedInput], mapExit, tasklet, [paddedDims], + [copyParams]) + + sqrsum = state.add_transient( + label + "_Sqrsum", shortAccesses, _tensortype(node), toplevel=True) + mapEntry, mapExit = state.add_map(label + "_sqrsum", + dict(zip(longParams, longDims))) + tasklet = state.add_tasklet(label + "_sqrsum", {'j0'}, {'out'}, + "out=j0*j0") + self.reinitCR(sqrsum, [shortParams], [shortDims], "0") + self.add_in_memlets([paddedInput], mapEntry, tasklet, [paddedDims], + [normParams]) + self.add_out_memlets([sqrsum], mapExit, tasklet, [shortDims], + [shortParams], 'lambda a,b: a+b', 0) + + mapEntry, mapExit = state.add_map(label, + dict(zip(shortParams, shortDims))) + tasklet = state.add_tasklet( + _string_builder(node.name), {'j0', 'j1'}, {'out'}, + "out = j0/(pow(" + bias + "+" + alpha + "*j1," + beta + "));", + language=dace.types.Language.CPP) + self.add_in_memlets((inputNodes + [sqrsum]), mapEntry, tasklet, + [shortDims, shortDims], [shortParams, shortParams]) + self.add_out_memlets(outputList, mapExit, tasklet, [shortDims], + [shortParams]) + + def visit_ArgMax(self, node): + + state = self.state + inputList = [] + inputNodes = [] + + for count, inp in enumerate(node.inputs): + if (count == 0): + inputNode = self.create_and_add_input_node(inp)[0] + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + + inputAccesses = [[], []] + inputDims = [[], []] + 
inputParams = [[], []] + for i, dim in enumerate(inp.shape): + if (i == 0): + inputAccesses[1].append(str(dim)) + inputParams[1].append("i" + str(i)) + inputDims[1].append("0:" + str(dim)) + inputAccesses[0].append(str(dim)) + inputParams[0].append("i" + str(i)) + inputDims[0].append("0:" + str(dim)) + + outputList = self.create_and_add_output_node(node) + + mapLabel = _string_builder(node.name) + mapEntry, mapExit = state.add_map( + mapLabel + "_max", dict(zip(inputParams[0], inputDims[0]))) + dtype = dace.typeclass(_tensortype(node)) + shape = dace.properties.ShapeProperty.from_string(",".join( + inputAccesses[1])) + temp1Node = state.add_transient( + mapLabel + "_max_tmp", shape, dtype, toplevel=True) + + tasklet = state.add_tasklet(mapLabel + "_max", {'j0'}, {'out'}, + "out = j0") + self.reinitCR(temp1Node, [inputParams[1]], [inputDims[1]], + "-999999999999") + self.add_in_memlets([inputNodes[0]], mapEntry, tasklet, [inputDims[0]], + [inputParams[0]]) + self.add_out_memlets([temp1Node], mapExit, tasklet, [inputDims[1]], + [inputParams[1]], 'lambda a,b: max(a,b)', + -999999999999) + + mapEntry, mapExit = state.add_map( + mapLabel + "_arg", dict(zip(inputParams[0], inputDims[0]))) + outputNode = outputList[0] + tasklet = state.add_tasklet(mapLabel + "_map2", {'j0', 'j1'}, {'out'}, + "if (j0==j1):\n\tout=i1") + self.add_in_memlets([inputNodes[0], temp1Node], mapEntry, tasklet, + inputDims, inputParams) + self.add_out_memlets([outputNode], mapExit, tasklet, [inputDims[1]], + [inputParams[1]]) + + def visit_Cast(self, node): + + state = self.state + inputList = [] + inputNodes = [] + outputList = [] + mapParams = [] + mapRange = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + castType = None + + dtype = node.get_attr("DstT") + if dtype.as_numpy_dtype == object: + raise NotImplementedError( + 'Type %s is not a valid numpy type' % str(dtype)) + castType = dace.typeclass(dtype.as_numpy_dtype).ctype + + for count, inp in enumerate(node.inputs): + if (count == 0): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + for out in node.outputs: + params = self.get_default_params(out) + dims = self.get_default_dims(out) + outputParams.append(params) + outputDims.append(dims) + + mapLabel = _string_builder(node.type) + mapParams = inputParams[0] + mapRange = inputDims[0] + mapEntry, mapExit = state.add_map(mapLabel, + dict(zip(mapParams, mapRange))) + tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, + "out = " + castType + "(j0)") + self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims, + inputParams) + self.add_out_memlets(outputList, mapExit, tasklet, outputDims, + outputParams) + + def visit_Print(self, node): + inputList = [] + inputNodes = [] + outputList = [] + state = self.state + mapParams = [] + mapRange = [] + outputParams = [] + outputDims = [] + inputParams = [] + inputDims = [] + + for count, inp in enumerate(node.inputs): + if (count == 0): + inputNode, params, dims = self.create_and_add_input_node(inp) + inputList.append(inputNode.desc(self.graph)) + inputNodes.append(inputNode) + inputParams.append(params) + inputDims.append(dims) + + outputList = self.create_and_add_output_node(node) + for out in node.outputs: + params = self.get_default_params(out) + dims = self.get_default_dims(out) + outputParams.append(params) + outputDims.append(dims) + + 
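+        # Print is lowered to an elementwise identity map; its tasklet also
+        # prints each value that passes through with printf.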
mapLabel = _string_builder(node.type)
+        mapParams = inputParams[0]
+        mapRange = inputDims[0]
+        mapEntry, mapExit = state.add_map(mapLabel,
+                                          dict(zip(mapParams, mapRange)))
+
+        taskletCode = "out = j0\nif(True):\n\tprintf(\"%f\\n\",out)"
+        tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, taskletCode)
+        self.add_out_memlets(outputList, mapExit, tasklet, outputDims,
+                             outputParams)
+        self.add_in_memlets(inputNodes, mapEntry, tasklet, inputDims,
+                            inputParams)
+
+    def visit_Softmax(self, node):
+
+        inputList = []
+        inputNodes = []
+        state = self.state
+
+        for inp in node.inputs:
+            label = _string_builder(inp.name)
+            try:
+                inputNode = state.find_node(label)
+            except (LookupError):
+                inputNode = self.create_and_add_input_node(inp)[0]
+            inputList.append(inputNode.desc(self.graph))
+            inputNodes.append(inputNode)
+
+        outputList = self.create_and_add_output_node(node)
+
+        inputDims = [[], []]
+        inputParams = [[], []]
+
+        for i, dim in enumerate(inp.shape):
+            if (i == 0):
+                inputParams[1].append("i" + str(i))
+                inputDims[1].append("0:" + str(dim))
+            inputParams[0].append("i" + str(i))
+            inputDims[0].append("0:" + str(dim))
+
+        mapLabel = _string_builder(node.name)
+        mapParams = inputParams[0]
+        mapRange = inputDims[0]
+
+        # 1st map, get maximum in the batch dimension
+        dtype = dace.typeclass(_tensortype(node))
+        shape = dace.properties.ShapeProperty.from_string(
+            str(node.inputs[0].shape.dims[0]))
+        temp1Node = state.add_transient(
+            mapLabel + "_max_tmp", shape, dtype, toplevel=True)
+        mapEntry, mapExit = state.add_map(mapLabel + "_max",
+                                          dict(zip(mapParams, mapRange)))
+        tasklet = state.add_tasklet(mapLabel + "_max", {'j0'}, {'out'},
+                                    "out = j0")
+        self.reinitCR(temp1Node, [inputParams[1]], [inputDims[1]],
+                      "-999999999999")
+        self.add_in_memlets([inputNodes[0]], mapEntry, tasklet, [inputDims[0]],
+                            [inputParams[0]])
+        self.add_out_memlets([temp1Node], mapExit, tasklet, [inputDims[1]],
+                             [inputParams[1]], 'lambda a,b: max(a,b)',
+                             -999999999999)
+
+        # 2nd map, calculate the denominator sum
+        temp2Node = state.add_transient(
+            mapLabel + "_denominator_tmp", shape, dtype, toplevel=True)
+        mapEntry, mapExit = state.add_map(mapLabel + "_denominator",
+                                          dict(zip(mapParams, mapRange)))
+        tasklet = state.add_tasklet(
+            mapLabel + "_denominator", {'j0', 'j1'}, {'out'},
+            "out = dace::math::exp(j0-j1);",
+            language=dace.types.Language.CPP)
+        self.reinitCR(temp2Node, [inputParams[1]], [inputDims[1]], "0")
+        inList = [inputNodes[0], temp1Node]
+        self.add_in_memlets(inList, mapEntry, tasklet, inputDims, inputParams)
+        self.add_out_memlets([temp2Node], mapExit, tasklet, [inputDims[1]],
+                             [inputParams[1]], 'lambda a,b: a+b', 0)
+
+        # 3rd map, calculate the softmax
+        mapEntry, mapExit = state.add_map(mapLabel + "_softmax",
+                                          dict(zip(mapParams, mapRange)))
+        tasklet = state.add_tasklet(
+            mapLabel + "_softmax", {'j0', 'j1', 'j2'}, {'out'},
+            "out = (dace::math::exp(j0-j1))/j2;",
+            language=dace.types.Language.CPP)
+        inList = [inputNodes[0], temp1Node, temp2Node]
+        paramsList = inputParams + [inputParams[1]]
+        dimsList = inputDims + [inputDims[1]]
+        self.add_in_memlets(inList, mapEntry, tasklet, dimsList, paramsList)
+        self.add_out_memlets(outputList, mapExit, tasklet, [inputDims[0]],
+
[inputParams[0]])
+
+    def add_in_memlets(self, inputList, otherNode, tasklet, inputDims,
+                       inputParams):
+        """ Convenience function that adds two memlets for each input of the
+            node: one external and one internal to a given map.
+            @param inputList: List of input DaCe access nodes.
+            @param otherNode: DaCe node (mostly map_entry).
+            @param tasklet: Normally a tasklet node, but it can also be
+                            another mapEntry (e.g., a map within a map).
+            @param inputDims: List of lists of dimension strings, one list
+                              per input. Example:
+                              [["0:5","0:7"],["0:2","0:4"]]
+            @param inputParams: List of lists of parameter strings, one list
+                                per input. Example: [["i0","i1"],["i2","i3"]]
+        """
+        state = self.state
+        connected_nodes = set()
+        for i, inp in enumerate(inputList):
+            assert isinstance(inputDims[i], list)
+            if inp.data not in connected_nodes:
+                outerMemlet = Memlet.simple(inp, ",".join(inputDims[i]))
+                state.add_edge(inp, None, otherNode, None, outerMemlet)
+                connected_nodes.add(inp.data)
+            name = "j" + str(i)
+            innerMemlet = Memlet.simple(inp, ",".join(inputParams[i]))
+
+            if isinstance(tasklet, (Tasklet, NestedSDFG)):
+                state.add_edge(otherNode, None, tasklet, name, innerMemlet)
+            else:
+                state.add_edge(otherNode, None, tasklet, None, innerMemlet)
+
+    def add_out_memlets(self,
+                        outputList,
+                        otherNode,
+                        tasklet,
+                        outputDims,
+                        outputParams,
+                        wcr=None,
+                        wcr_identity=None):
+        """ Convenience function that adds two memlets for each output of the
+            node: one external and one internal to a given map.
+            @param outputList: List of output DaCe access nodes.
+            @param otherNode: DaCe node (mostly map_entry).
+            @param tasklet: Normally a tasklet node, but it can also be
+                            another mapEntry (e.g., a map within a map).
+            @param outputDims: List of lists of dimension strings, one list
+                               per output. Example:
+                               [["0:5","0:7"],["0:2","0:4"]]
+            @param outputParams: List of lists of parameter strings, one list
+                                 per output. Example:
+                                 [["i0","i1"],["i2","i3"]]
+            @param wcr: (optional) Write-conflict resolution function (as
+                        string).
+            @param wcr_identity: (optional) Identity element for write-conflict
+                                 resolution.
+        """
+
+        connected_nodes = set()
+
+        state = self.state
+        for i, out in enumerate(outputList):
+            assert isinstance(outputDims[i], list)
+            if (len(outputList) > 1):
+                name = "out" + str(i)
+            else:
+                name = "out"
+
+            if out.data not in connected_nodes:
+                outerMemlet = Memlet.simple(
+                    out,
+                    ",".join(outputDims[i]),
+                    wcr_str=wcr,
+                    wcr_identity=wcr_identity)
+                state.add_edge(otherNode, None, out, None, outerMemlet)
+                connected_nodes.add(out.data)
+            innerMemlet = Memlet.simple(
+                out,
+                ",".join(outputParams[i]),
+                wcr_str=wcr,
+                wcr_identity=wcr_identity)
+
+            if isinstance(tasklet, (Tasklet, NestedSDFG)):
+                state.add_edge(tasklet, name, otherNode, None, innerMemlet)
+            else:
+                state.add_edge(tasklet, None, otherNode, None, innerMemlet)
+
+    def create_and_add_input_node(self, inp):
+        """ Creates a DaCe access node for the input tensor `inp`, adds it
+            to the state, and returns it; if the node already exists, the
+            pre-existing node is returned instead.
+            @param inp: tf.Tensor.
+            @return: A 3-tuple of (input DaCe access node,
+                                   list of parameter strings,
+                                   list of dimension strings).
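+
+            Example (illustrative): for a 2-D tensor of shape [5, 7], the
+            returned parameters are ["i0", "i1"] and the returned dimensions
+            are ["0:5", "0:7"].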
+ """ + + state = self.state + # Get DaCe name of the operation + label = _string_builder(inp.name) + # Try to find node in DaCe graph + try: + # If successful, use the existing node + inputNode = state.find_node(label) + except (LookupError): + # Get type and shape of the input tensor + dtype = dace.typeclass(_tensortype(inp)) + shape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(inp))) + # Create and add array, default is transient, toplevel =True + inputNode = state.add_transient( + name=label, shape=shape, dtype=dtype, toplevel=True) + + params = self.get_default_params(inp) + dims = self.get_default_dims(inp) + + return inputNode, params, dims + + def create_and_add_output_node(self, node): + """ Creates a DaCe access node for each output of `node`, adds it to + the state, and returns it. + If the node already exists, returns the pre-existing node. + @param node: tf.Operation + @return: List of DaCe access node. + """ + outputList = [] + state = self.state + # Iterate over all output nodes + for count, out in enumerate(node.outputs): + label = _string_builder(out.name) + # Try to find node in DaCe graph + try: + # If successful, use the existing node + outputNode = state.find_node(label) + except (LookupError): + # Get type and shape of the tensor + dtype = dace.typeclass(_tensortype(out)) + shape = dace.properties.ShapeProperty.from_string( + str(_tensorshape(out))) + outputNode = state.add_transient( + label, shape, dtype, toplevel=True) + outputList.append(outputNode) + return outputList + + def reinitCR(self, inp, params, dims, identity): + """ Adds a reinitialization map to a `reinit` state, setting inputs + to their initial values. Only used in training mode. + @param inp: DaCe access node. + @param params: List of string parameters to `inp`. + @param dims: List of strings dimensions of `inp`. + @param identity: Identity value of the CR node (as a string) + """ + + if self.training: + # Swap current state and reinitState + self.state, self.reinitState = self.reinitState, self.state + node = inp + state = self.state + dtype = node.desc(self.graph).dtype + label = node.label + + # Mark node as non-transient as we need to set it from the outside + # the SDFG. + node.desc(self.graph).transient = False + + shape = dace.properties.ShapeProperty.from_string( + str(inp.desc(self.graph).shape)) + # Add input, output and map to reinitState + inputNode = state.add_array(label, shape, dtype) + outputNode = state.add_array(label, shape, dtype) + mapEntry, mapExit = state.add_map(label, + dict(zip(params[0], dims[0]))) + + # Output is set to identity + tasklet = state.add_tasklet(label, set(), {'out'}, + "out = " + identity) + state.add_edge(mapEntry, None, tasklet, None, EmptyMemlet()) + self.add_out_memlets([outputNode], mapExit, tasklet, dims, params) + # Add numpy array with identity value to the reinit dict. + npArray = np.full(shape, int(identity)).astype( + node.desc(self.graph).dtype.type) + self.reinitDict.update({label: npArray}) + # Swap state back + self.reinitState, self.state = self.state, self.reinitState + else: + pass + + def inputPadding(self, node, inpnode, inp, outputSize, kernelSize, strides, + inputDims): + """ Zero-pads the input to fit the outputSize. + @param node: tf.Operation + @param inpnode: DaCe access node to pad + @param outputSize: Output size. + @param kernelSize: Kernel size. + @param strides: Strides. + @param inputDims: List of strings (e.g.["0:N","0:M"]). 
+            @return: A 2-tuple (output DaCe access node with padded input,
+                                list of dimension strings of the padded data).
+        """
+        state = self.state
+        paddingUp = 0
+        paddingDown = 0
+        label = inpnode.label
+        inputSize = inp.shape[1]
+        # Calculate padding according to the paper
+        padding = strides * (outputSize - 1) + kernelSize - inputSize
+        # If the padding is even, pad the same amount on each side
+        if (padding % 2 == 0):
+            paddingUp = padding // 2
+            paddingDown = padding // 2
+        # If the padding is odd, pad more on the bottom and on the right side
+        # of an image (matching TensorFlow behavior)
+        else:
+            paddingUp = padding // 2
+            paddingDown = paddingUp + 1
+
+        # Set up the different padding dimensions, accesses and params.
+        outputAccesses = list(map(str, list(inp.shape)))
+        outputAccesses[1] += "+" + str(paddingUp) + "+" + str(paddingDown)
+        outputAccesses[2] += "+" + str(paddingUp) + "+" + str(paddingDown)
+        outputDims = []
+        inputParams = []
+        for i, dim in enumerate(outputAccesses):
+            inputParams.append("i" + str(i))
+            outputDims.append("0:" + dim)
+
+        outputParams = inputParams.copy()
+        outputParams[1] += "+" + str(paddingUp)
+        outputParams[2] += "+" + str(paddingUp)
+
+        # Add the padded input to the graph, set it to zero, and add the map.
+        shape = dace.properties.ShapeProperty.from_string(
+            ",".join(outputAccesses))
+        output = state.add_transient(
+            label + "_padded", shape=shape, dtype=inp.dtype, toplevel=True)
+        output.desc(self.graph).setzero = True
+
+        mapParams = inputParams
+        mapRange = inputDims
+        mapLabel = _string_builder(node.type)
+        mapEntry, mapExit = state.add_map(mapLabel,
+                                          dict(zip(mapParams, mapRange)))
+        tasklet = state.add_tasklet(mapLabel, {'j0'}, {'out'}, "out = j0")
+        self.add_in_memlets([inpnode], mapEntry, tasklet, [inputDims],
+                            [inputParams])
+        self.add_out_memlets([output], mapExit, tasklet, [outputDims],
+                             [outputParams])
+        return output, outputDims
+
+    def get_default_params(self, tensor, start=0):
+        """ Returns the default parameters of a tensor starting at `start`,
+            e.g., ["i0","i1",...].
+            @param tensor: tf.Tensor.
+            @param start: Starting position for the iteration.
+            @return: List of parameters as strings ["i0","i1",...].
+        """
+        params = []
+        shape = _tensorshape(tensor)
+        if shape == 1:
+            shape = [1]
+        for i, dim in enumerate(shape, start):
+            params.append("i" + str(i))
+        return params
+
+    def get_default_dims(self, tensor):
+        """ Returns the default dimensions of a tensor, e.g., ["0:N","0:M"].
+            @param tensor: tf.Tensor.
+ @return: List of dimensions as strings ["0:N","0:M"] + """ + dims = [] + shape = _tensorshape(tensor) + if shape == 1: + shape = [1] + for dim in shape: + dims.append("0:" + str(dim)) + return dims diff --git a/dace/graph/__init__.py b/dace/graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dace/graph/dot.py b/dace/graph/dot.py new file mode 100644 index 0000000000..d7839dba05 --- /dev/null +++ b/dace/graph/dot.py @@ -0,0 +1,201 @@ +import copy +import html +from dace import data, memlet +from dace.graph import graph as gr, edges + + +def draw_edge_explicit(srcName, dstName, edge, sdfg, graph, **extraOpts): + opts = {} + if isinstance(edge.data, memlet.Memlet): + if getattr(edge.data, '__label__', False): + opts["label"] = edge.data.__label__(sdfg, graph) + else: + opts["label"] = str(edge.data) + if edge.data.wcr is not None: + opts['style'] = 'dashed' + elif isinstance(edge.data, edges.InterstateEdge): + opts.update(edge.data.dotOpts) + # Unhandled properties + elif edge.data != None: + raise ValueError("Unhandled edge: " + str(edge.data)) + if extraOpts: + opts.update(extraOpts) # Custom options will overwrite default + + if isinstance(edge, gr.MultiConnectorEdge): + sconn = '' if edge.src_conn is None else (':' + edge.src_conn) + dconn = '' if edge.dst_conn is None else (':' + edge.dst_conn) + else: + sconn = '' + dconn = '' + + return ("\"{}\"{sconn} -> \"{}\"{dconn}".format( + srcName, dstName, sconn=sconn, dconn=dconn) + ((" [" + ", ".join( + ["{}=\"{}\"".format(key, value) + for key, value in opts.items()]) + "];") if opts else ";")) + + +def draw_edge(sdfg, graph, edge, **extraOpts): + srcName = 's%d_%d' % (sdfg.node_id(graph), graph.node_id(edge.src)) + dstName = 's%d_%d' % (sdfg.node_id(graph), graph.node_id(edge.dst)) + + return draw_edge_explicit(srcName, dstName, edge, sdfg, graph) + + +def draw_interstate_edge(sdfg, src_graph, dst_graph, edge, **extraOpts): + srcName = 's%d_%d' % (sdfg.node_id(src_graph), src_graph.node_id(edge.src)) + dstName = 's%d_%d' % (sdfg.node_id(dst_graph), dst_graph.node_id(edge.dst)) + if isinstance(edge, gr.MultiConnectorEdge): + if edge.src_conn is not None: + srcName += '@' + edge.src_conn + if edge.dst_conn is not None: + dstName += '@' + edge.dst_conn + + return draw_edge_explicit(srcName, dstName, edge, sdfg, src_graph, + **extraOpts) + + +def draw_interstate_edge_by_name(srcName, dstName, edge, sdfg, src_graph, + **extraOpts): + return draw_edge_explicit(srcName, dstName, edge, sdfg, src_graph, + **extraOpts) + + +def draw_node(sdfg, graph, obj, **kwargs): + name = 's%d_%d' % (sdfg.node_id(graph), graph.node_id(obj)) + if getattr(obj, '__label__', False): + opts = {"label": obj.__label__(sdfg, graph)} + else: + opts = {"label": str(obj)} + opts.update(kwargs) + opts["label"] = "\"{}\"".format(opts["label"]) + + if 'fillcolor' not in opts: + opts['fillcolor'] = '"#ffffff"' + if 'style' not in opts: + opts['style'] = 'filled' + else: + opts['style'] = '"filled,%s"' % opts['style'] + + ############################################ + if getattr(obj, 'in_connectors', False) != False and len( + obj.in_connectors) + len(obj.out_connectors) > 0: + # Header + code = '{name} [label=<' + code = code.format(name=name) + # Input connectors + code += '' + code += '' + connector_code = [] + for conn in sorted(obj.in_connectors): + connector_code.append( + '{conn}'. 
+ format(conn=conn)) + code += ''.join(connector_code) + code += '' + + # Contents + html_label = html.escape(opts['label'][1:-1]) + code += '{label}'.format( + label=html_label) + + # Output connectors + code += '' + code += '' + connector_code = [] + for conn in sorted(obj.out_connectors): + connector_code.append( + '{conn}'. + format(conn=conn)) + code += ''.join(connector_code) + code += '' + + # Footer + code += '>' + + filtered_opts = {k: v for k, v in opts.items() if k != 'label'} + if len(filtered_opts.items()) > 0: + ostr = ", ".join([ + str(key) + "=" + str(val) + for key, val in filtered_opts.items() + ]) + code += ', ' + ostr + code += '];\n' + + return code + ############################################ + + return "\"{}\" [{}];".format( + name, + ", ".join([str(key) + "=" + str(val) for key, val in opts.items()])) + + +def draw_invisible_node(name, **kwargs): + opts = dict(label='\"\"', style="invisible") + opts.update(kwargs) + return "\"{}\" [{}];".format( + name, + ", ".join([str(key) + "=" + str(val) for key, val in opts.items()])) + + +def draw_graph(sdfg, graph, standalone=True): + """ Creates a graphviz dot file from a networkx graph input. + + If standalone is set, return a full dot string including header and footer. + """ + state_id = sdfg.node_id(graph) + sdfg = copy.deepcopy(sdfg) + graph = sdfg.nodes()[state_id] + + sdict = graph.scope_dict() + sdict_children = graph.scope_dict(True) + + # Omit collapsed nodes out of nodes to draw + def is_collapsed(node): + scope = sdict[node] + while scope is not None: + if scope.is_collapsed: + return True + scope = sdict[scope] + return False + + nodes_to_draw = set( + node for node in graph.nodes() if not is_collapsed(node)) + + # Collect edges to draw for collapsed nodes (we also need edges coming out of scope exits) + nodes_for_edges = set() + nodes_for_edges.update(nodes_to_draw) + + def add_exit_nodes(scope): + for node in sdict_children[scope]: + if node in sdict_children and node.is_collapsed: + nodes_for_edges.add(graph.exit_nodes(node)[0]) + elif node in sdict_children: + add_exit_nodes(node) + + add_exit_nodes(None) + + edges_to_draw = set( + e for e in graph.edges() + if e.src in nodes_for_edges and e.dst in nodes_for_edges) + + # Take care of scope entry connectors + for node in nodes_to_draw: + if node in sdict_children and node.is_collapsed: + node._out_connectors.clear() + + # Take care of scope exit edges and connectors + for e in edges_to_draw: + if e.src in nodes_for_edges and e.src not in nodes_to_draw: + newsrc = sdict[e.src] + if newsrc is None: + continue + e._src = newsrc + newsrc._out_connectors.add(e.src_conn) + + nodes = [x.draw_node(sdfg, graph) for x in nodes_to_draw] + edges = [draw_edge(sdfg, graph, e) for e in edges_to_draw] + + if not standalone: + return nodes, edges + + return "digraph DaCe {{\n {}\n}}".format("\n ".join(nodes + edges)) diff --git a/dace/graph/edges.py b/dace/graph/edges.py new file mode 100644 index 0000000000..67faa234c9 --- /dev/null +++ b/dace/graph/edges.py @@ -0,0 +1,285 @@ +import ast +import copy +import enum +import re + +import dace +from dace import types +from dace.graph.graph import Edge +from dace.frontend.python import astutils +from dace.properties import Property, CodeProperty, make_properties + + +def assignments_from_string(astr): + """ Returns a dictionary of assignments from a semicolon-delimited + string of expressions. 
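+
+        Example (illustrative): `assignments_from_string("i = i + 1; j = 0")`
+        returns `{'i': 'i + 1', 'j': '0'}`.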
""" + + result = {} + for aitem in astr.split(';'): + aitem = aitem.strip() + m = re.search(r'([^=\s]+)\s*=\s*([^=]+)', aitem) + result[m.group(1)] = m.group(2) + + return result + + +def assignments_to_string(assdict): + """ Returns a semicolon-delimited string from a dictionary of assignment + expressions. """ + return '; '.join(['%s=%s' % (k, v) for k, v in assdict.items()]) + + +@make_properties +class InterstateEdge(object): + """ An SDFG state machine edge. These edges can contain a condition + (which may include data accesses for data-dependent decisions) and + zero or more assignments of values to inter-state variables (e.g., + loop iterates). + """ + + assignments = Property( + dtype=dict, + desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')", + from_string=assignments_from_string, + to_string=assignments_to_string) + condition = CodeProperty(desc="Transition condition") + language = Property(enum=types.Language, default=types.Language.Python) + + def __init__(self, condition=None, assignments=None): + + if condition is None: + condition = ast.parse("1").body[0] + + if assignments is None: + assignments = {} + + self.condition = condition + self.assignments = assignments + + self._dotOpts = {"minlen": 3, "color": "blue", "fontcolor": "blue"} + + def is_unconditional(self): + """ Returns True if the state transition is unconditional. """ + return (self.condition == None or InterstateEdge.condition.to_string( + self.condition).strip() == "1") + + def condition_sympy(self): + cond_ast = self.condition + return symbolic.pystr_to_symbolic(astutils.unparse(cond_ast)) + + def condition_symbols(self): + return dace.symbolic.symbols_in_ast(self.condition[0]) + + def toJSON(self, indent=0): + json = str(self.label) + # get rid of newlines (why are they there in the first place?) + json = re.sub(r"\n", " ", json) + return "\"" + json + "\"" + + @property + def label(self): + assignments = ','.join( + ['%s=%s' % (k, v) for k, v in self.assignments.items()]) + + # Edge with assigment only (no condition) + if astutils.unparse(self.condition) == '1': + # Edge without conditions or assignments + if len(self.assignments) == 0: + return '' + return assignments + + # Edge with condition only (no assignment) + if len(self.assignments) == 0: + return astutils.unparse(self.condition) + + # Edges with assigments and conditions + return assignments + '; ' + astutils.unparse(self.condition) + + @property + def dotOpts(self): + result = {} + result.update(self._dotOpts) + result.update({'label': self.label}) + return result + + +class RedirectEdge(InterstateEdge): + """ An inter-state edge type used for rendering self-looping edges + on graph clusters in GraphViz. """ + + def __init__(self): + super(RedirectEdge, self).__init__() + self._dotOpts["arrowhead"] = "none" + + +############################################################################### +# Various classes to facilitate the detection of control flow elements (e.g., +# `for`, `if`, `while`) from state machines in SDFGs. 
+ + +@make_properties +class ControlFlowScope: + + nodes_in_scope = Property( + dtype=set, + desc="Nodes contained in this scope, " + "including entry and exit nodes, in topological order.") + + def __init__(self, nodes_in_scope): + self.nodes_in_scope = nodes_in_scope + + def __contains__(self, node): + return node in self.nodes_in_scope + + def __iter__(self): + return iter(self.nodes_in_scope) + + +# make_properties will be called after adding cyclic class reference members +class LoopScope(ControlFlowScope): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.assignment = None + self.entry = None + self.back = None + self.exit = None + + +class ControlFlow: + pass + + +@make_properties +class LoopAssignment(ControlFlow): + + scope = Property(dtype=LoopScope) + edge = Property(dtype=Edge) + + def __init__(self, scope, edge, *args, **kwargs): + self.scope = scope + self.edge = edge + scope.assignment = self + super().__init__(*args, **kwargs) + + +@make_properties +class LoopEntry(ControlFlow): + + scope = Property(dtype=LoopScope) + edge = Property(dtype=Edge) + + def __init__(self, scope, edge, *args, **kwargs): + self.scope = scope + self.edge = edge + scope.entry = self + super().__init__(*args, **kwargs) + + +@make_properties +class LoopExit(ControlFlow): + + scope = Property(dtype=LoopScope) + edge = Property(dtype=Edge) + + def __init__(self, scope, edge, *args, **kwargs): + self.scope = scope + self.edge = edge + scope.exit = self + super().__init__(*args, **kwargs) + + +@make_properties +class LoopBack(ControlFlow): + + scope = Property(dtype=LoopScope) + edge = Property(dtype=Edge) + + def __init__(self, scope, edge, *args, **kwargs): + self.scope = scope + self.edge = edge + scope.back = self + super().__init__(*args, **kwargs) + + +# These will be assigned when the various control flow objects are created +LoopScope.assignment = Property(dtype=LoopAssignment, allow_none=True) +LoopScope.entry = Property(dtype=LoopEntry, allow_none=True) +LoopScope.back = Property(dtype=LoopBack, allow_none=True) +LoopScope.exit = Property(dtype=LoopExit, allow_none=True) +LoopScope = make_properties(LoopScope) + + +# Extra meta-object binding together then and else scopes. 
+# make_properties will be called after adding cyclic class reference members
+class IfThenElse:
+
+    entry = Property()
+    exit = Property()
+
+    def __init__(self, entry, exit):
+        self.entry = entry
+        self.exit = exit
+        self.then_scope = None
+        self.else_scope = None
+
+
+@make_properties
+class IfEntry(ControlFlow):
+
+    scope = Property(dtype=ControlFlowScope)
+    edge = Property(dtype=Edge)
+
+    def __init__(self, scope, edge, *args, **kwargs):
+        self.scope = scope
+        self.edge = edge
+        scope.entry = self
+        super().__init__(*args, **kwargs)
+
+
+@make_properties
+class IfExit(ControlFlow):
+
+    scope = Property(dtype=ControlFlowScope)
+    edge = Property(dtype=Edge)
+
+    def __init__(self, scope, edge, *args, **kwargs):
+        self.scope = scope
+        self.edge = edge
+        scope.exit = self
+        super().__init__(*args, **kwargs)
+
+
+@make_properties
+class IfThenScope(ControlFlowScope):
+
+    if_then_else = Property(dtype=IfThenElse)
+    entry = Property(dtype=IfEntry, allow_none=True)
+    exit = Property(dtype=IfExit, allow_none=True)
+
+    def __init__(self, if_then_else, *args, **kwargs):
+        self.if_then_else = if_then_else
+        if_then_else.then_scope = self
+        self.entry = None
+        self.exit = None
+        super().__init__(*args, **kwargs)
+
+
+@make_properties
+class IfElseScope(ControlFlowScope):
+
+    if_then_else = Property(dtype=IfThenElse)
+    entry = Property(dtype=IfEntry, allow_none=True)
+    exit = Property(dtype=IfExit, allow_none=True)
+
+    def __init__(self, if_then_else, *args, **kwargs):
+        self.if_then_else = if_then_else
+        if_then_else.else_scope = self
+        self.entry = None
+        self.exit = None
+        super().__init__(*args, **kwargs)
+
+
+# Cyclic class reference
+IfThenElse.then_scope = Property(dtype=IfThenScope, allow_none=True)
+IfThenElse.else_scope = Property(dtype=IfElseScope, allow_none=True)
+IfThenElse = make_properties(IfThenElse)
diff --git a/dace/graph/graph.py b/dace/graph/graph.py
new file mode 100644
index 0000000000..03c487bd48
--- /dev/null
+++ b/dace/graph/graph.py
@@ -0,0 +1,711 @@
+""" Graph and multigraph implementations for DaCe. """
+
+from collections import deque, OrderedDict
+import itertools
+import networkx as nx
+from dace.types import deduplicate
+
+
+class NodeNotFoundError(Exception):
+    pass
+
+
+class EdgeNotFoundError(Exception):
+    pass
+
+
+class Edge(object):
+    def __init__(self, src, dst, data):
+        self._src = src
+        self._dst = dst
+        self._data = data
+
+    @property
+    def src(self):
+        return self._src
+
+    @property
+    def dst(self):
+        return self._dst
+
+    @property
+    def data(self):
+        return self._data
+
+    def __iter__(self):
+        yield self._src
+        yield self._dst
+        yield self._data
+
+    def toJSON(self, indent=0):
+        if self._data is None:
+            return "null"
+        return self._data.toJSON(indent)
+
+    @staticmethod
+    def __len__():
+        return 3
+
+    def reverse(self):
+        self._src, self._dst = self._dst, self._src
+
+
+class MultiEdge(Edge):
+    def __init__(self, src, dst, data, key):
+        super(MultiEdge, self).__init__(src, dst, data)
+        self._key = key
+
+    def toJSON(self, indent=0):
+        # NOTE: the multigraph edge key is lost in this serialization
+        if self._data is None:
+            return "null"
+        return self._data.toJSON(indent)
+
+    @property
+    def key(self):
+        return self._key
+
+
+class MultiConnectorEdge(MultiEdge):
+    def __init__(self, src, src_conn, dst, dst_conn, data, key):
+        super(MultiConnectorEdge, self).__init__(src, dst, data, key)
+        self._src_conn = src_conn
+        self._dst_conn = dst_conn
+
+    def toJSON(self, indent=0):
+        # NOTE: the multigraph edge key is lost in this serialization
+ return ('%s' % ("null" + if self._data is None else self._data.toJSON(indent))) + + @property + def src_conn(self): + return self._src_conn + + @property + def src_connector(self): + return self._src_conn + + @property + def dst_conn(self): + return self._dst_conn + + @property + def dst_connector(self): + return self._dst_conn + + def __iter__(self): + yield self._src + yield self._src_conn + yield self._dst + yield self._dst_conn + yield self._data + + @staticmethod + def __len__(): + return 5 + + +class Graph(object): + def _not_implemented_error(self): + return NotImplementedError("Not implemented for " + str(type(self))) + + def toJSON(self, indent=0): + json = " " * indent + "{\n" + indent += 2 + json += " " * indent + "\"type\": \"" + type(self).__name__ + "\",\n" + json += " " * indent + "\"nodes\": [\n" + indent += 2 + for n in self.nodes(): + json += " " * indent + "{\n" + indent += 2 + json += " " * indent + "\"id\" : \"" + str( + self.node_id(n)) + "\",\n" + json += " " * indent + "\"attributes\" : " + n.toJSON(indent) + "\n" + indent -= 2 + if n == self.nodes()[-1]: + json += " " * indent + "}\n" + else: + json += " " * indent + "},\n" + indent -= 2 + json += " " * indent + "],\n" + + json += " " * indent + "\"edges\": [\n" + for e in self.edges(): + json += " " * indent + "{\n" + indent += 2 + json += " " * indent + "\"src\" : \"" + str(self.node_id( + e.src)) + "\",\n" + if isinstance(e, MultiConnectorEdge): + json += " " * indent + '"src_connector" : "%s",\n' % e.src_conn + json += " " * indent + "\"dst\" : \"" + str(self.node_id( + e.dst)) + "\",\n" + if isinstance(e, MultiConnectorEdge): + json += " " * indent + '"dst_connector" : "%s",\n' % e.dst_conn + json += " " * indent + "\"attributes\" : " + e.toJSON(indent) + "\n" + indent -= 2 + if e == self.edges()[-1]: + json += " " * indent + "}\n" + else: + json += " " * indent + "},\n" + indent -= 2 + json += " " * indent + "]\n" + json += " " * indent + "}\n" + return json + + def nodes(self): + """Returns an iterable to internal graph nodes.""" + raise self._not_implemented_error() + + def edges(self): + """Returns an iterable to internal graph edges.""" + raise self._not_implemented_error() + + def in_edges(self, node): + """Returns an iterable to Edge objects.""" + raise self._not_implemented_error() + + def out_edges(self, node): + """Returns an iterable to Edge objects.""" + raise self._not_implemented_error() + + def all_edges(self, *nodes): + """Returns an iterable to incoming and outgoing Edge objects.""" + result = set() + for node in nodes: + result.update(self.in_edges(node)) + result.update(self.out_edges(node)) + return list(result) + + def add_node(self, node): + """Adds node to the graph.""" + raise self._not_implemented_error() + + def add_nodes_from(self, node_list): + """Adds nodes from an iterable to the graph""" + for node in node_list: + self.add_node(node) + + def node_id(self, node): + """Returns a numeric node ID that corresponds to the node index in the + internal graph representation (unique).""" + for i, n in enumerate(self.nodes()): + if node == n: + return i + raise NodeNotFoundError(node) + + def add_edge(self, source, destination, data): + """Adds an edge to the graph containing the specified data. 
Returns the added edge."""
+        raise self._not_implemented_error()
+
+    def remove_node(self, node):
+        """Removes the specified node."""
+        raise self._not_implemented_error()
+
+    def remove_nodes_from(self, node_list):
+        """Removes the nodes specified in an iterable."""
+        for node in node_list:
+            self.remove_node(node)
+
+    def remove_edge(self, edge):
+        """Removes the specified Edge object."""
+        raise self._not_implemented_error()
+
+    def edges_between(self, source, destination):
+        """Returns all edges that connect source and destination directly"""
+        raise self._not_implemented_error()
+
+    def predecessors(self, node):
+        """Returns an iterable of nodes that have edges leading to the passed
+        node"""
+        return deduplicate([e.src for e in self.in_edges(node)])
+
+    def successors(self, node):
+        """Returns an iterable of nodes that have edges originating from the
+        passed node"""
+        return deduplicate([e.dst for e in self.out_edges(node)])
+
+    def neighbors(self, node):
+        return itertools.chain(self.predecessors(node), self.successors(node))
+
+    def in_degree(self, node):
+        """Returns the number of incoming edges to the specified node."""
+        raise self._not_implemented_error()
+
+    def out_degree(self, node):
+        """Returns the number of outgoing edges from the specified node."""
+        raise self._not_implemented_error()
+
+    def number_of_nodes(self):
+        """Returns the total number of nodes in the graph."""
+        raise self._not_implemented_error()
+
+    def number_of_edges(self):
+        """Returns the total number of edges in the graph."""
+        raise self._not_implemented_error()
+
+    def is_directed(self):
+        raise self._not_implemented_error()
+
+    def is_multigraph(self):
+        raise self._not_implemented_error()
+
+    def __iter__(self):
+        return iter(self.nodes())
+
+    def __len__(self):
+        """ Returns the total number of nodes in the graph (nx compatibility)"""
+        return self.number_of_nodes()
+
+    def bfs_edges(self, node, reverse=False):
+        """Returns a generator over edges in the graph originating from the
+        passed node in BFS order"""
+        if isinstance(node, (tuple, list)):
+            queue = deque(node)
+        else:
+            queue = deque([node])
+        visited = set()
+        while len(queue) > 0:
+            node = queue.popleft()
+            if node in visited:
+                continue
+            visited.add(node)
+            edges = (self.out_edges(node)
+                     if not reverse else self.in_edges(node))
+            for e in edges:
+                next_node = e.dst if not reverse else e.src
+                if next_node not in visited:
+                    queue.append(next_node)
+                yield e
+
+    def dfs_edges(self, source, condition=None):
+        """Traverse a graph (DFS) with an optional condition to filter out
+        nodes"""
+        if isinstance(source, list):
+            nodes = source
+        else:
+            nodes = [source]
+        visited = set()
+        for start in nodes:
+            if start in visited:
+                continue
+            visited.add(start)
+            stack = [(start, self.out_edges(start).__iter__())]
+            while stack:
+                parent, children = stack[-1]
+                try:
+                    e = next(children)
+                    if e.dst not in visited:
+                        visited.add(e.dst)
+                        if condition is None or condition(
+                                e.src, e.dst, e.data):
+                            yield e
+                            stack.append((e.dst,
+                                          self.out_edges(e.dst).__iter__()))
+                except StopIteration:
+                    stack.pop()
+
+    def source_nodes(self):
+        """Returns nodes with no incoming edges."""
+        return [n for n in self.nodes() if self.in_degree(n) == 0]
+
+    def sink_nodes(self):
+        """Returns nodes with no outgoing edges."""
+        return [n for n in self.nodes() if self.out_degree(n) == 0]
+
+    def topological_sort(self, source=None):
+        """Returns nodes in a topological-sort-like (BFS) order, starting from
+        `source` if given, otherwise from a unique source node if one exists
+        (falling back to the first node)."""
+        if source is not None:
+            sources = [source]
+
else:
+            sources = self.source_nodes()
+            if len(sources) == 0:
+                sources = [self.nodes()[0]]
+                #raise RuntimeError("No source nodes found")
+            if len(sources) > 1:
+                sources = [self.nodes()[0]]
+                #raise RuntimeError("Multiple source nodes found")
+        seen = OrderedDict()  # No OrderedSet in Python
+        queue = deque(sources)
+        while len(queue) > 0:
+            node = queue.popleft()
+            seen[node] = None
+            for e in self.out_edges(node):
+                succ = e.dst
+                if succ not in seen:
+                    seen[succ] = None
+                    queue.append(succ)
+        return seen.keys()
+
+    def all_simple_paths(self, source_node, dest_node):
+        """ Finds all simple paths (with no repeating nodes) from source_node
+            to dest_node """
+        return nx.all_simple_paths(self._nx, source_node, dest_node)
+
+
+class SubgraphView(Graph):
+    def __init__(self, graph, subgraph_nodes):
+        self._graph = graph
+        self._subgraph_nodes = subgraph_nodes
+        self._parallel_parent = None
+
+    def is_parallel(self):
+        return self._parallel_parent is not None
+
+    def set_parallel_parent(self, parallel_parent):
+        self._parallel_parent = parallel_parent
+
+    def get_parallel_parent(self):
+        return self._parallel_parent
+
+    def nodes(self):
+        return self._subgraph_nodes
+
+    def edges(self):
+        return [
+            e for e in self._graph.edges()
+            if e.src in self._subgraph_nodes and e.dst in self._subgraph_nodes
+        ]
+
+    def in_edges(self, node):
+        if node not in self._subgraph_nodes:
+            raise NodeNotFoundError
+
+        return [
+            e for e in self._graph.in_edges(node)
+            if e.src in self._subgraph_nodes
+        ]
+
+    def out_edges(self, node):
+        if node not in self._subgraph_nodes:
+            raise NodeNotFoundError
+
+        return [
+            e for e in self._graph.out_edges(node)
+            if e.dst in self._subgraph_nodes
+        ]
+
+    def add_node(self, node):
+        raise PermissionError
+
+    def add_nodes_from(self, node_list):
+        raise PermissionError
+
+    def node_id(self, node):
+        if node not in self._subgraph_nodes:
+            raise NodeNotFoundError
+        return self._graph.node_id(node)
+
+    def add_edge(self, source, destination, data):
+        raise PermissionError
+
+    def remove_node(self, node):
+        raise PermissionError
+
+    def remove_nodes_from(self, node_list):
+        raise PermissionError
+
+    def remove_edge(self, edge):
+        raise PermissionError
+
+    def edges_between(self, source, destination):
+        if source not in self._subgraph_nodes or \
+           destination not in self._subgraph_nodes:
+            raise NodeNotFoundError
+        return self._graph.edges_between(source, destination)
+
+    def in_degree(self, node):
+        return len(self.in_edges(node))
+
+    def out_degree(self, node):
+        return len(self.out_edges(node))
+
+    def number_of_nodes(self):
+        return len(self._subgraph_nodes)
+
+    def number_of_edges(self):
+        return len(self.edges())
+
+    def is_directed(self):
+        return self._graph.is_directed()
+
+    def is_multigraph(self):
+        return self._graph.is_multigraph()
+
+
+class DiGraph(Graph):
+    def __init__(self):
+        self._nx = nx.DiGraph()
+
+    def nodes(self):
+        return self._nx.nodes()
+
+    @staticmethod
+    def _from_nx(edge):
+        return Edge(edge[0], edge[1], edge[2]["data"])
+
+    def edges(self):
+        return [DiGraph._from_nx(e) for e in self._nx.edges(data=True)]
+
+    def in_edges(self, node):
+        return [
+            DiGraph._from_nx(e) for e in self._nx.in_edges(node, data=True)
+        ]
+
+    def out_edges(self, node):
+        return [
+            DiGraph._from_nx(e) for e in self._nx.out_edges(node, data=True)
+        ]
+
+    def add_node(self, node):
+        return self._nx.add_node(node)
+
+    def add_edge(self, source, destination, data):
+        return self._nx.add_edge(source, destination, data=data)
+
+    def remove_node(self, node):
+        self._nx.remove_node(node)
+
+    def remove_edge(self, edge):
+
self._nx.remove_edge(edge[0], edge[1]) + + def in_degree(self, node): + return self._nx.in_degree(node) + + def out_degree(self, node): + return self._nx.out_degree(node) + + def number_of_nodes(self): + return self._nx.number_of_nodes() + + def number_of_edges(self): + return self._nx.number_of_edges() + + def is_directed(self): + return True + + def is_multigraph(self): + return False + + def edges_between(self, source, destination): + return [e for e in self.out_edges(source) if e.dst == destination] + + def find_cycles(self): + return nx.simple_cycles(self._nx) + + +class MultiDiGraph(DiGraph): + def __init__(self): + self._nx = nx.MultiDiGraph() + + @staticmethod + def _from_nx(edge): + return MultiEdge(edge[0], edge[1], edge[3]["data"], edge[2]) + + def add_edge(self, source, destination, data): + key = self._nx.add_edge(source, destination, data=data) + return (source, destination, data, key) + + def remove_edge(self, edge): + self._nx.remove_edge(edge[0], edge[1], edge.key) + + def is_multigraph(self): + return True + + +class MultiDiConnectorGraph(MultiDiGraph): + def __init__(self): + super().__init__() + + @staticmethod + def _from_nx(edge): + return MultiConnectorEdge(edge[0], edge[3]["src_conn"], edge[1], + edge[3]["dst_conn"], edge[3]["data"], + edge[2]) + + def add_edge(self, source, src_connector, destination, dst_connector, + data): + key = self._nx.add_edge( + source, + destination, + data=data, + src_conn=src_connector, + dst_conn=dst_connector) + return (source, src_connector, destination, dst_connector, data, key) + + def remove_edge(self, edge): + self._nx.remove_edge(edge[0], edge[1], edge.key) + + def is_multigraph(self): + return True + + +class OrderedDiGraph(Graph): + """ Directed graph where nodes and edges are returned in the order they + were added. 
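+
+        Example (illustrative):
+            g = OrderedDiGraph()
+            g.add_node('a')
+            g.add_edge('a', 'b', None)  # 'b' is added automatically
+            assert g.nodes() == ['a', 'b']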
""" + + def __init__(self): + self._nx = nx.DiGraph() + # {node: ({in edge: None}, {out edges: None})} + self._nodes = OrderedDict() + # {(src, dst): edge} + self._edges = OrderedDict() + + @property + def nx(self): + return self._nx + + def node(self, id): + return list(self._nodes.keys())[id] + + def nodes(self): + return list(self._nodes.keys()) + + def edges(self): + return list(self._edges.values()) + + def in_edges(self, node): + return list(self._nodes[node][0].values()) + + def out_edges(self, node): + return list(self._nodes[node][1].values()) + + def add_node(self, node): + if node in self._nodes: + raise RuntimeError("Duplicate node added") + self._nodes[node] = (OrderedDict(), OrderedDict()) + self._nx.add_node(node) + + def add_edge(self, src, dst, data): + t = (src, dst) + if t in self._edges: + raise RuntimeError("Duplicate edge added") + if src not in self._nodes: + self.add_node(src) + if dst not in self._nodes: + self.add_node(dst) + edge = Edge(src, dst, data) + self._edges[t] = edge + self._nodes[src][1][t] = edge + self._nodes[dst][0][t] = edge + return self._nx.add_edge(src, dst, data=data) + + def remove_node(self, node): + for edge in itertools.chain(self.in_edges(node), self.out_edges(node)): + self.remove_edge(edge) + del self._nodes[node] + self._nx.remove_node(node) + + def remove_edge(self, edge): + src = edge.src + dst = edge.dst + t = (src, dst) + self._nx.remove_edge(src, dst) + del self._nodes[src][1][t] + del self._nodes[dst][0][t] + del self._edges[t] + + def in_degree(self, node): + return len(self._nodes[node][0]) + + def out_degree(self, node): + return len(self._nodes[node][1]) + + def number_of_nodes(self): + return len(self._nodes) + + def number_of_edges(self): + return len(self._edges) + + def is_directed(self): + return True + + def is_multigraph(self): + return False + + def find_cycles(self): + return nx.simple_cycles(self._nx) + + def edges_between(self, source, destination): + if source not in self.nodes(): return [] + return [e for e in self.out_edges(source) if e.dst == destination] + + def reverse(self): + """Reverses source and destination of all edges in the graph""" + raise self._not_implemented_error() + + +class OrderedMultiDiGraph(OrderedDiGraph): + """ Directed multigraph where nodes and edges are returned in the order + they were added. """ + + def __init__(self): + self._nx = nx.MultiDiGraph() + # {node: ({in edge: edge}, {out edge: edge})} + self._nodes = OrderedDict() + # {edge: edge} + self._edges = OrderedDict() + + def add_edge(self, src, dst, data): + key = self._nx.add_edge(src, dst, data=data) + edge = MultiEdge(src, dst, data, key) + if src not in self._nodes: + self.add_node(src) + if dst not in self._nodes: + self.add_node(dst) + self._nodes[src][1][edge] = edge + self._nodes[dst][0][edge] = edge + self._edges[edge] = edge + return edge + + def remove_edge(self, edge): + del self._edges[edge] + del self._nodes[edge.src][1][edge] + del self._nodes[edge.dst][0][edge] + self._nx.remove_edge(edge.src, edge.dst, edge.key) + + def reverse(self): + self._nx.reverse(False) + for e in self._edges.keys(): + e.reverse() + for n, (in_edges, out_edges) in self._nodes.items(): + self._nodes[n] = (out_edges, in_edges) + + def is_multigraph(self): + return True + + +class OrderedMultiDiConnectorGraph(OrderedMultiDiGraph): + """ Directed multigraph with node connectors (SDFG states), where nodes + and edges are returned in the order they were added. 
""" + + def __init__(self): + super().__init__() + + def add_edge(self, src, src_conn, dst, dst_conn, data): + key = self._nx.add_edge( + src, dst, data=data, src_conn=src_conn, dst_conn=dst_conn) + edge = MultiConnectorEdge(src, src_conn, dst, dst_conn, data, key) + if src not in self._nodes: + self.add_node(src) + if dst not in self._nodes: + self.add_node(dst) + self._nodes[src][1][edge] = edge + self._nodes[dst][0][edge] = edge + self._edges[edge] = edge + return edge + + def add_nedge(self, src, dst, data): + """ Adds an edge without (value=None) connectors. """ + return self.add_edge(src, None, dst, None, data) + + def remove_edge(self, edge): + del self._edges[edge] + del self._nodes[edge.src][1][edge] + del self._nodes[edge.dst][0][edge] + self._nx.remove_edge(edge.src, edge.dst, edge.key) + + def reverse(self): + self._nx.reverse(False) + for e in self._edges.keys(): + e.reverse() + for n, (in_edges, out_edges) in self._nodes.items(): + self._nodes[n] = (out_edges, in_edges) + + def is_multigraph(self): + return True diff --git a/dace/graph/labeling.py b/dace/graph/labeling.py new file mode 100644 index 0000000000..c47734cb48 --- /dev/null +++ b/dace/graph/labeling.py @@ -0,0 +1,813 @@ +""" Functionality relating to Memlet propagation (deducing external memlets + from internal memory accesses and scope ranges). """ + +import copy +import itertools +import functools +import networkx as nx +import sympy +import unittest +import math + +from dace import data, subsets, symbolic, types +from dace.memlet import Memlet +from dace.graph import nodes, nxutil +from dace.graph.graph import OrderedMultiDiGraph +from dace.transformation import pattern_matching + + +class MemletPattern(object): + """ A pattern match on a memlet subset that can be used for propagation. + """ + s_patterns = [] + s_dependencies = {} + + @staticmethod + def patterns(): + return [p() for p in MemletPattern.s_patterns] + + @staticmethod + def register_pattern(clazz, depends=None): + if not issubclass(clazz, MemletPattern): + raise TypeError + MemletPattern.s_patterns.append(clazz) + + @staticmethod + def unregister_pattern(clazz): + if not issubclass(clazz, MemletPattern): + raise TypeError + MemletPattern.s_patterns.remove(clazz) + + #################################################### + + def match(self, expressions, variable_context, node_range, orig_edges): + raise NotImplementedError + + def propagate(self, array, expressions, node_range): + raise NotImplementedError + + +class SeparableMemletPattern(object): + """ Memlet pattern that can be applied to each of the dimensions + separately. """ + + s_smpatterns = [] + + @staticmethod + def register_pattern(cls): + if not issubclass(cls, SeparableMemletPattern): raise TypeError + if cls not in SeparableMemletPattern.s_smpatterns: + SeparableMemletPattern.s_smpatterns.append(cls) + + @staticmethod + def unregister_pattern(cls): + SeparableMemletPattern.s_smpatterns.remove(cls) + + def match(self, dim_exprs, variable_context, node_range, orig_edges, + dim_index, total_dims): + raise NotImplementedError + + def propagate(self, array, dim_exprs, node_range): + raise NotImplementedError + + +class SeparableMemlet(MemletPattern): + """ Meta-memlet pattern that applies all separable memlet patterns. 
""" + + def match(self, expressions, variable_context, node_range, orig_edges): + # Assuming correct dimensionality in each of the expressions + data_dims = len(expressions[0]) + self.patterns_per_dim = [None] * data_dims + + overapprox_range = subsets.Range([(rb.approx if isinstance( + rb, symbolic.SymExpr) else rb, re.approx if isinstance( + re, symbolic.SymExpr) else re, rs.approx if isinstance( + rs, symbolic.SymExpr) else rs) + for rb, re, rs in node_range]) + + for dim in range(data_dims): + + dexprs = [] + for expr in expressions: + if isinstance(expr[dim], symbolic.SymExpr): + dexprs.append(expr[dim].approx) + elif isinstance(expr[dim], tuple): + dexprs.append( + (expr[dim][0].approx + if isinstance(expr[dim][0], symbolic.SymExpr) else + expr[dim][0], expr[dim][1].approx + if isinstance(expr[dim][1], symbolic.SymExpr) else + expr[dim][1], expr[dim][2].approx + if isinstance(expr[dim][2], + symbolic.SymExpr) else expr[dim][2])) + else: + dexprs.append(expr[dim]) + + for pattern_class in SeparableMemletPattern.s_smpatterns: + smpattern = pattern_class() + if smpattern.match(dexprs, variable_context, overapprox_range, + orig_edges, dim, data_dims): + self.patterns_per_dim[dim] = smpattern + break + + return None not in self.patterns_per_dim + + def propagate(self, array, expressions, node_range): + result = [(None, None, None)] * len(self.patterns_per_dim) + + overapprox_range = subsets.Range([(rb.approx if isinstance( + rb, symbolic.SymExpr) else rb, re.approx if isinstance( + re, symbolic.SymExpr) else re, rs.approx if isinstance( + rs, symbolic.SymExpr) else rs) + for rb, re, rs in node_range]) + + for i, smpattern in enumerate(self.patterns_per_dim): + + dexprs = [] + for expr in expressions: + if isinstance(expr[i], symbolic.SymExpr): + dexprs.append(expr[i].approx) + elif isinstance(expr[i], tuple): + dexprs.append((expr[i][0].approx if isinstance( + expr[i][0], + symbolic.SymExpr) else expr[i][0], expr[i][1].approx + if isinstance(expr[i][1], symbolic.SymExpr) + else expr[i][1], expr[i][2].approx + if isinstance(expr[i][2], symbolic.SymExpr) + else expr[i][2], expr.tile_sizes[i])) + else: + dexprs.append(expr[i]) + + result[i] = smpattern.propagate(array, dexprs, overapprox_range) + + # TODO(later): Not necessarily Range (general integer sets) + return subsets.Range(result) + + +MemletPattern.register_pattern(SeparableMemlet) + + +class AffineSMemlet(SeparableMemletPattern): + """ Separable memlet pattern that matches affine expressions, i.e., + of the form `a * {index} + b`. + """ + + def match(self, dim_exprs, variable_context, node_range, orig_edges, + dim_index, total_dims): + + params = variable_context[-1] # Why only last element? 
+        # Create wildcards for multiplication and addition
+        a = sympy.Wild('a', exclude=params)
+        b = sympy.Wild('b', exclude=params)
+
+        self.param = None
+        self.paramind = None
+        self.multiplier = None
+        self.add_min = None
+        self.add_max = None
+        self.constant_min = None
+        self.constant_max = None
+
+        # Obtain vector length
+        self.veclen = None
+        if dim_index == total_dims - 1:
+            for e in orig_edges:
+                self.veclen = e.veclen
+        if self.veclen is None:
+            self.veclen = 1
+
+        # Special case: Get the total internal access range
+        # If this range matches (0, rs), we say that the propagated skip is 1
+        self.internal_range = set()
+
+        for dexpr in dim_exprs:
+            subexprs = None
+            step = None
+            if isinstance(dexpr, sympy.Basic):  # Affine index
+                subexprs = [dexpr]
+
+            elif isinstance(dexpr, tuple) and len(dexpr) == 3:  # Affine range
+                subexprs = [dexpr[0], dexpr[1]]
+                step = dexpr[2]
+
+            if subexprs is None:  # Something else
+                return False
+
+            for i, subexpr in enumerate(subexprs):
+                # Try to match an affine expression with a parameter
+                param = None
+                pind = -1
+                for indp, p in enumerate(params):
+                    matches = subexpr.match(a * p + b)
+                    if param is None and matches is None:
+                        continue
+                    elif param is not None and matches is not None:
+                        return False  # Only one parameter may match
+                    elif matches is not None:
+                        multiplier = matches[a]
+                        addition = matches[b]
+                        param = p
+                        pind = indp
+
+                if param is None:
+                    return False  # A parameter must match
+                if self.param is not None and param != self.param:
+                    return False  # There can only be one parameter
+                if (self.multiplier is not None
+                        and multiplier != self.multiplier):
+                    return False  # Multiplier must be the same
+
+                self.param = param
+                self.paramind = pind
+                self.multiplier = multiplier
+
+                # If this is one expression
+                if len(subexprs) == 1:
+                    self.internal_range.add(addition)
+                elif i == 0:  # Range begin
+                    brb = addition
+                elif i == 1:  # Range end
+                    bre = addition
+
+            if len(subexprs) > 1:
+                self.internal_range.add((brb, bre))
+
+            if step is not None:
+                if self.param in step.free_symbols:
+                    return False  # Step must be independent of parameter
+
+        node_rb, node_re, node_rs = node_range[self.paramind]
+        if node_rs != 1:
+            # Map ranges with a non-unit step (where the last index is not
+            # known exactly) are not supported by this pattern.
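+            # (e.g., a map over i = 0:N:2 is rejected here and left to a
+            # more generic pattern)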
+ return False + + if self.param is None: # and self.constant_min is None: + return False + + return True + + def propagate(self, array, dim_exprs, node_range): + # Compute last index in map according to range definition + node_rb, node_re, node_rs = node_range[self.paramind] # node_rs = 1 + node_rlen = node_re - node_rb + 1 + + if isinstance(dim_exprs, list): + dim_exprs = dim_exprs[0] + + if isinstance(dim_exprs, tuple): + + if len(dim_exprs) == 3: + rb, re, rs = dim_exprs + rt = '1' + elif len(dim_exprs) == 4: + rb, re, rs, rt = dim_exprs + else: + raise NotImplementedError + + rb = symbolic.pystr_to_symbolic(rb).expand() + re = symbolic.pystr_to_symbolic(re).expand() + rs = symbolic.pystr_to_symbolic(rs).expand() + rt = symbolic.pystr_to_symbolic(rt).expand() + else: + rb, re = (dim_exprs.expand(), dim_exprs.expand()) + rs = 1 + rt = 1 + + result_begin = rb.subs(self.param, node_rb).expand() + result_end = re.subs(self.param, node_re).expand() + + # Experimental + # This should be using sympy.floor + memlet_start_pts = ((re - rt + 1 - rb) / rs) + 1 + memlet_rlen = memlet_start_pts.expand() * rt + interval_len = (result_end - result_begin + 1) * self.veclen + num_elements = node_rlen * memlet_rlen + + if (interval_len == num_elements + or interval_len.expand() == num_elements): + # Continuous access + result_skip = 1 + result_tile = 1 + else: + if rt == 1: + result_skip = (result_end - result_begin - re + rb) / ( + node_re - node_rb) + try: + if result_skip < 1: + result_skip = 1 + except: + pass + result_tile = result_end - result_begin + 1 - ( + node_rlen - 1) * result_skip + else: + candidate_skip = rs + candidate_tile = rt * node_rlen + candidate_lstart_pt = result_end - result_begin + 1 - candidate_tile + if (candidate_lstart_pt / (num_elements / candidate_tile - 1) + ).simplify() == candidate_skip: + result_skip = rs + result_tile = rt * node_rlen + else: + result_skip = rs / node_rlen + result_tile = rt + + if result_skip == result_tile or result_skip == 1: + result_skip = 1 + result_tile = 1 + + result_begin = sympy.simplify(result_begin) + result_end = sympy.simplify(result_end) + result_skip = sympy.simplify(result_skip) + result_tile = sympy.simplify(result_tile) + + return (result_begin, result_end, result_skip, result_tile) + + +SeparableMemletPattern.register_pattern(AffineSMemlet) + + +class ModuloSMemlet(SeparableMemletPattern): + """ Separable memlet pattern that matches modulo expressions, i.e., + of the form `f(x) % N`. + + Acts as a meta-pattern: Finds the underlying pattern for `f(x)`. 
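+
+        For example, ``(i + 1) % N`` is handled by first matching the affine
+        subexpression ``i + 1``, then clamping the propagated range to the
+        modulo value wherever it may exceed it.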
+ """ + + def match(self, dim_exprs, variable_context, node_range, orig_edges, + dim_index, total_dims): + # Pattern does not support unions of expressions + if len(dim_exprs) > 1: return False + dexpr = dim_exprs[0] + # Pattern does not support ranges + if not isinstance(dexpr, sympy.Basic): return False + + # Create wildcards + val = sympy.Wild('val') + mod = sympy.Wild('mod', exclude=variable_context[-1]) + + # Try to match an affine expression + matches = dexpr.match(val % mod) + if matches is None or len(matches) != 2: + return False + + self.subexpr = matches[val] + self.modulo = matches[mod] + + self.subpattern = None + for pattern_class in SeparableMemletPattern.s_smpatterns: + smpattern = pattern_class() + if smpattern.match([self.subexpr], variable_context, node_range, + orig_edges, dim_index, total_dims): + self.subpattern = smpattern + + return self.subpattern is not None + + def propagate(self, array, dim_exprs, node_range): + se_range = self.subpattern.propagate(array, [self.subexpr], node_range) + + # Apply modulo on start and end ranges + try: + if se_range[0] < 0: + se_range = (0, self.modulo, se_range[2]) + except TypeError: # cannot determine truth value of Relational + print('WARNING: Cannot evaluate relational %s, assuming true.' % + (se_range[0] < 0)) + try: + if se_range[1] > self.modulo: + se_range = (0, self.modulo, se_range[2]) + except TypeError: # cannot determine truth value of Relational + print('WARNING: Cannot evaluate relational %s, assuming true.' % + (se_range[1] > self.modulo)) + + return se_range + + +SeparableMemletPattern.register_pattern(ModuloSMemlet) + + +class ConstantSMemlet(SeparableMemletPattern): + """ Separable memlet pattern that matches constant (i.e., unrelated to + current scope) expressions. + """ + + def match(self, dim_exprs, variable_context, node_range, orig_edges, + dim_index, total_dims): + # Pattern does not support unions of expressions. TODO: Support + if len(dim_exprs) > 1: return False + dexpr = dim_exprs[0] + + # Create a wildcard that excludes current map's parameters + cst = sympy.Wild('cst', exclude=variable_context[-1]) + + # Range case + if isinstance(dexpr, tuple) and len(dexpr) == 3: + # Try to match a constant expression for the range + for rngelem in dexpr: + if types.isconstant(rngelem): + continue + + matches = rngelem.match(cst) + if matches is None or len(matches) != 1: + return False + if not matches[cst].is_constant(): + return False + + else: # Single element case + # Try to match a constant expression + if not types.isconstant(dexpr): + matches = dexpr.match(cst) + if matches is None or len(matches) != 1: + return False + if not matches[cst].is_constant(): + return False + + return True + + def propagate(self, array, dim_exprs, node_range): + if isinstance(dim_exprs[0], tuple): + return dim_exprs[0] # Already in range format + # Convert index to range format + return (dim_exprs[0], dim_exprs[0], 1) + + +SeparableMemletPattern.register_pattern(ConstantSMemlet) + + +class GenericSMemlet(SeparableMemletPattern): + """ Separable memlet pattern that detects any expression, and propagates + interval bounds. Used as a last resort. 
""" + + def match(self, dim_exprs, variable_context, node_range, orig_edges, + dim_index, total_dims): + + self.params = variable_context[-1] + + # Always matches + return True + + def propagate(self, array, dim_exprs, node_range): + + result_begin = None + result_end = None + + # Iterate over the node dimensions + for idx, node_r in enumerate(node_range): + + # Get dimension range + if len(node_r) == 3: + node_rb, node_re, node_rs = node_r + elif len(node_r) == 4: + node_rb, node_re, node_rs, _ = node_r + else: + raise NotImplementedError + + # Get true range end + lastindex = node_re + if node_rs != 1: + lastindex = symbolic.pystr_to_symbolic( + '%s + int_floor(%s - %s, %s) * %s' % + (symbolic.symstr(node_rb), symbolic.symstr(node_re), + symbolic.symstr(node_rb), symbolic.symstr(node_rs), + symbolic.symstr(node_rs))) + + if isinstance(dim_exprs, list): + dim_exprs = dim_exprs[0] + + if isinstance(dim_exprs, tuple): + + if len(dim_exprs) == 3: + rb, re, rs = dim_exprs + elif len(dim_exprs) == 4: + rb, re, rs, _ = dim_exprs + else: + raise NotImplementedError + + rb = symbolic.pystr_to_symbolic(rb) + re = symbolic.pystr_to_symbolic(re) + rs = symbolic.pystr_to_symbolic(rs) + + else: + rb, re = (dim_exprs, dim_exprs) + + if result_begin is None: + result_begin = rb.subs(self.params[idx], node_rb) + else: + result_begin = result_begin.subs(self.params[idx], node_rb) + if result_end is None: + result_end = re.subs(self.params[idx], lastindex) + else: + result_end = result_end.subs(self.params[idx], lastindex) + + result_skip = 1 + result_tile = 1 + + return (result_begin, result_end, result_skip, result_tile) + + +SeparableMemletPattern.register_pattern(GenericSMemlet) + + +def _subexpr(dexpr, repldict): + if isinstance(dexpr, tuple): + return tuple(_subexpr(d, repldict) for d in dexpr) + elif isinstance(dexpr, symbolic.SymExpr): + return dexpr.expr.subs(repldict) + else: + return dexpr.subs(repldict) + + +class ConstantRangeMemlet(MemletPattern): + """ Memlet pattern that matches arbitrary expressions with constant range. + """ + + def match(self, expressions, variable_context, node_range, orig_edges): + constant_range = True + for dim in node_range: + for rngelem in dim: # For (begin, end, skip) + if not types.isconstant(rngelem) and not isinstance( + rngelem, sympy.Number): + constant_range = False + break + if not constant_range: + return False + + self.params = variable_context[-1] + + return True + + # TODO: An integer set library should shine here (unify indices) + def propagate(self, array, expressions, node_range): + rng = [(None, None, 1)] * len(array.shape) + node_range_gen = (range(rb, re, rs) for rb, re, rs in node_range) + for ndind in itertools.product(*tuple(node_range_gen)): + repldict = {p: ndind[i] for i, p in enumerate(self.params)} + for expr in expressions: + for dim, dexpr in enumerate(expr): + evaldexpr = _subexpr(dexpr, repldict) + rb, re, rs = rng[dim] + if rb is None: + rng[dim] = (evaldexpr, evaldexpr, 1) + else: + if evaldexpr < rb: + rng[dim] = (evaldexpr, re, rs) + if evaldexpr > re: # The +1 is because ranges are exclusive + rng[dim] = (rb, evaldexpr, rs) + + return subsets.Range(rng) + + +# ConstantRangeMemlet is slow, so it should be evaluated last +MemletPattern.register_pattern(ConstantRangeMemlet) + + +def propagate_labels_sdfg(sdfg): + """ Propagates memlets throughout an entire given SDFG. + @note: This is an in-place operation on the SDFG. 
+ """ + for state in sdfg.nodes(): + _propagate_labels(state, sdfg) + + +def _propagate_labels(g, sdfg): + """ Propagates memlets throughout one SDFG state. + @param g: The state to propagate in. + @param sdfg: The SDFG in which the state is situated. + @note: This is an in-place operation on the SDFG state. + """ + patterns = MemletPattern.patterns() + + # Algorithm: + # 1. Start propagating information from tasklets outwards (their edges + # are hardcoded). + # NOTE: This process can be performed in parallel. + # 2. Traverse the neighboring nodes (topological sort, first forward to + # outputs and then backward to inputs). + # There are four possibilities: + # a. If the neighboring node is a tasklet, skip (such edges are + # immutable) + # b. If the neighboring node is an array, make sure it is the correct + # array. Otherwise, throw a mismatch exception. + # c. If the neighboring node is a scope node, and its other edges are + # not set, set the results per-array, using the union of the + # obtained ranges in the previous depth. + # d. If the neighboring node is a scope node, and its other edges are + # already set, verify the results per-array, using the union of the + # obtained ranges in the previous depth. + # NOTE: The SDFG creation process ensures that all edges in the + # multigraph are tagged with the appropriate array. In any case + # of ambiguity, the function raises an exception. + # 3. For each edge in the multigraph, collect results and group by array assigned to edge. + # Accumulate information about each array in the target node. + scope_dict = g.scope_dict() + + def stop_at(parent, child): + # Transients should only propagate in the direction of the + # non-transient data + if isinstance(parent, + nodes.AccessNode) and parent.desc(sdfg).transient: + for _, _, _, _, memlet in g.edges_between(parent, child): + if parent.data != memlet.data: + return True + return False + if isinstance(child, nodes.AccessNode): + return False + return True + + array_data = {} # type: dict(node -> dict(data -> list(Subset))) + tasklet_nodes = [ + node for node in g.nodes() if (isinstance(node, nodes.CodeNode) or ( + isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient)) + ] + # Step 1: Direction - To output + for start_node in tasklet_nodes: + for node in nxutil.dfs_topological_sort( + g, start_node, condition=stop_at): + _propagate_node(sdfg, g, node, array_data, patterns, scope_dict, + True) + # Step 1: Direction - To input + array_data = {} + g.reverse() + for node in nxutil.dfs_topological_sort( + g, tasklet_nodes, condition=stop_at): + _propagate_node(sdfg, g, node, array_data, patterns, scope_dict) + + # To support networkx 1.11 + g.reverse() + + +# External API +def propagate_memlet(dfg_state, memlet: Memlet, scope_node: nodes.EntryNode, + union_inner_edges: bool): + """ Tries to propagate a memlet through a scope (computes the image of + the memlet function applied on an integer set of, e.g., a map range) + and returns a new memlet object. + @param dfg_state: An SDFGState object representing the graph. + @param memlet: The memlet adjacent to the scope node from the inside. + @param scope_node: A scope entry or exit node. + @param union_inner_edges: True if the propagation should take other + neighboring internal memlets within the same + scope into account. 
+ """ + if isinstance(scope_node, nodes.EntryNode): + neighboring_edges = dfg_state.out_edges(scope_node) + elif isinstance(scope_node, nodes.ExitNode): + neighboring_edges = dfg_state.in_edges(scope_node) + else: + raise TypeError('Trying to propagate through a non-scope node') + + # Find other adjacent edges within the connected to the scope node + # and union their subsets + if union_inner_edges: + aggdata = [ + e.data for e in neighboring_edges + if e.data.data == memlet.data and e.data != memlet + ] + else: + aggdata = [] + + aggdata.append(memlet) + + new_subset = _propagate_edge(dfg_state.parent, None, + scope_node, None, memlet, aggdata, + MemletPattern.patterns(), None) + + new_memlet = copy.copy(memlet) + new_memlet.subset = new_subset + new_memlet.other_subset = None + + # Number of accesses in the propagated memlet is the sum of the internal + # number of accesses times the size of the map range set + new_memlet.num_accesses = ( + sum(m.num_accesses for m in aggdata) * functools.reduce( + lambda a, b: a * b, scope_node.map.range.size(), 1)) + + return new_memlet + + +def _propagate_node(sdfg, + g, + node, + array_data, + patterns, + scope_dict, + write=False): + # Step 2: Propagate edges + # If this is a tasklet, we only propagate to adjacent nodes and not modify edges + # Special case: starting from reduction, no need for external nodes to compute edges + if (not isinstance(node, nodes.CodeNode) + and not isinstance(node, nodes.AccessNode) and node in array_data): + # Otherwise (if primitive), use current node information and accumulated data + # on arrays to set the memlets per edge + for _, _, target, _, memlet in g.out_edges(node): + # Option (a) + if (isinstance(target, nodes.CodeNode)): + continue + + if not isinstance(memlet, Memlet): + raise AttributeError('Edge does not contain a memlet') + + aggdata = None + if node in array_data: + if memlet.data in array_data[node]: + aggdata = array_data[node][memlet.data] + + wcr = None + if aggdata is not None: + for m in aggdata: + if m.wcr is not None: + wcr = (m.wcr, m.wcr_identity) + break + + # Compute candidate edge + candidate = _propagate_edge(sdfg, g, node, target, memlet, aggdata, + patterns, not write) + if candidate is None: + continue + + # Option (b) + if isinstance(target, nodes.AccessNode): + # Check for data mismatch + if target.data != memlet.data: #and not target.desc.transient: + raise LookupError( + 'Mismatch between edge data %s and data node %s' % + (memlet.data, target.data)) + + # Options (c), (d) + else: + pass + + # Set new edge value + memlet.subset = candidate + + # Number of accesses in the propagated memlet is the sum of the internal + # number of accesses times the size of the map range set + memlet.num_accesses = ( + sum(m.num_accesses for m in aggdata) * functools.reduce( + lambda a, b: a * b, node.map.range.size(), 1)) + + # Set WCR, if necessary + if wcr is not None: + memlet.wcr, memlet.wcr_identity = wcr + + # Step 3: Accumulate edge information in adjacent node, grouped by array + for _, _, target, _, memlet in g.out_edges(node): + if (isinstance(target, nodes.CodeNode)): + continue + + if not isinstance(memlet, Memlet): + raise AttributeError('Edge does not contain a memlet') + + # Transients propagate only towards the data they are writing to + if isinstance(node, nodes.AccessNode) and node.data == memlet.data: + continue + + # No data + if memlet.subset is None: + continue + #if isinstance(memlet, subsets.SequentialDependency): + # continue + + # Accumulate data information on target node 
+ if target not in array_data: + array_data[target] = {} + if memlet.data not in array_data[target]: + array_data[target][memlet.data] = [] + array_data[target][memlet.data].append(memlet) + + +def _propagate_edge(sdfg, g, u, v, memlet, aggdata, patterns, reversed): + if ((isinstance(u, nodes.EntryNode) or isinstance(u, nodes.ExitNode))): + mapnode = u.map + + if aggdata is None: + return None + + # Collect data about edge + data = memlet.data + expr = [edge.subset for edge in aggdata] + + if memlet.data not in sdfg.arrays: + raise KeyError('Data descriptor (Array, Stream) "%s" not defined ' + 'in SDFG.' % memlet.data) + + for pattern in patterns: + if pattern.match( + expr, + [[symbolic.pystr_to_symbolic(p) for p in mapnode.params]], + mapnode.range, aggdata): # Only one level of context + return pattern.propagate(sdfg.arrays[memlet.data], expr, + mapnode.range) + + # No patterns found. Emit a warning and propagate the entire array + print('WARNING: Cannot find appropriate memlet pattern to propagate %s' + % str(expr)) + + return subsets.Range.from_array(sdfg.arrays[memlet.data]) + elif isinstance(u, nodes.ConsumeEntry) or isinstance(u, nodes.ConsumeExit): + + # Nothing to analyze/propagate in consume + return subsets.Range.from_array(sdfg.arrays[memlet.data]) + + else: + raise NotImplementedError('Unimplemented primitive: %s' % type(u)) diff --git a/dace/graph/nodes.py b/dace/graph/nodes.py new file mode 100644 index 0000000000..58a30a6f9c --- /dev/null +++ b/dace/graph/nodes.py @@ -0,0 +1,749 @@ +""" Contains classes implementing the different types of nodes of the stateful + dataflow multigraph representation. """ + +import ast +from copy import deepcopy as dcpy +import itertools +from typing import Set +from dace.graph import dot, graph +from dace.frontend.python.astutils import unparse +from dace.properties import (Property, CodeProperty, LambdaProperty, + ParamsProperty, RangeProperty, DebugInfoProperty, + SetProperty, make_properties, indirect_properties, + DataProperty, SymbolicProperty) +from dace.frontend.operations import detect_reduction_type +from dace import data, subsets as sbs, types +import pickle + +# ----------------------------------------------------------------------------- + + +@make_properties +class Node(object): + """ Base node class. """ + + in_connectors = SetProperty( + str, default=set(), desc="A set of input connectors for this node.") + out_connectors = SetProperty( + str, default=set(), desc="A set of output connectors for this node.") + + def __init__(self, in_connectors=set(), out_connectors=set()): + self.in_connectors = in_connectors + self.out_connectors = out_connectors + + def __str__(self): + if hasattr(self, 'label'): + return self.label + else: + return type(self).__name__ + + def validate(self, sdfg, state): + pass + + def toJSON(self, indent=0): + labelstr = str(self) + typestr = str(type(self).__name__) + inconn = "[" + ",".join( + ['"' + str(x) + '"' for x in self.in_connectors]) + "]" + outconn = "[" + ",".join( + ['"' + str(x) + '"' for x in self.out_connectors]) + "]" + json = " " * indent + "{ \"label\": \"" + labelstr + json += "\", \"type\": \"" + typestr + "\", \"in_connectors\": " + inconn + json += ", \"out_connectors\" :" + outconn + json += "}\n" + return json + + def __repr__(self): + return type(self).__name__ + ' (' + self.__str__() + ')' + + def add_in_connector(self, connector_name: str): + """ Adds a new input connector to the node. 
The operation will fail if
+            a connector (either input or output) with the same name already
+            exists in the node.
+
+            @param connector_name: The name of the new connector.
+            @return: True if the operation is successful, otherwise False.
+        """
+
+        if (connector_name in self.in_connectors
+                or connector_name in self.out_connectors):
+            return False
+        connectors = self.in_connectors
+        connectors.add(connector_name)
+        self.in_connectors = connectors
+        return True
+
+    def add_out_connector(self, connector_name: str):
+        """ Adds a new output connector to the node. The operation will fail
+            if a connector (either input or output) with the same name
+            already exists in the node.
+
+            @param connector_name: The name of the new connector.
+            @return: True if the operation is successful, otherwise False.
+        """
+
+        if (connector_name in self.in_connectors
+                or connector_name in self.out_connectors):
+            return False
+        connectors = self.out_connectors
+        connectors.add(connector_name)
+        self.out_connectors = connectors
+        return True
+
+    def remove_in_connector(self, connector_name: str):
+        """ Removes an input connector from the node.
+            @param connector_name: The name of the connector to remove.
+            @return: True if the operation was successful.
+        """
+
+        if connector_name in self.in_connectors:
+            connectors = self.in_connectors
+            connectors.remove(connector_name)
+            self.in_connectors = connectors
+        return True
+
+    def remove_out_connector(self, connector_name: str):
+        """ Removes an output connector from the node.
+            @param connector_name: The name of the connector to remove.
+            @return: True if the operation was successful.
+        """
+
+        if connector_name in self.out_connectors:
+            connectors = self.out_connectors
+            connectors.remove(connector_name)
+            self.out_connectors = connectors
+        return True
+
+    def _next_connector_int(self) -> int:
+        """ Returns the next unused connector ID (as an integer). Used for
+            filling connectors when adding edges to scopes. """
+        next_number = 1
+        for conn in itertools.chain(self.in_connectors, self.out_connectors):
+            if conn.startswith('IN_'):
+                cconn = conn[3:]
+            elif conn.startswith('OUT_'):
+                cconn = conn[4:]
+            else:
+                continue
+            try:
+                curconn = int(cconn)
+                if curconn >= next_number:
+                    next_number = curconn + 1
+            except ValueError:  # connector suffix is not integral
+                continue
+        return next_number
+
+    def next_connector(self) -> str:
+        """ Returns the next unused connector ID (as a string). Used for
+            filling connectors when adding edges to scopes. """
+        return str(self._next_connector_int())
+
+    def last_connector(self) -> str:
+        """ Returns the last used connector ID (as a string). Used for
+            filling connectors when adding edges to scopes. """
+        return str(self._next_connector_int() - 1)
+
+
+# ------------------------------------------------------------------------------
+
+
+@make_properties
+class AccessNode(Node):
+    """ A node that accesses data in the SDFG. Denoted by a circular shape.
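+
+        Example (illustrative):
+
+            read_a = AccessNode('A')  # accesses the data descriptor named "A"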
""" + + access = Property( + enum=types.AccessType, + desc="Type of access to this array", + default=types.AccessType.ReadWrite) + setzero = Property(dtype=bool, desc="Initialize to zero", default=False) + debuginfo2 = DebugInfoProperty() + data = DataProperty(desc="Data (array, stream, scalar) to access") + + def __init__(self, data, access=types.AccessType.ReadWrite, + debuginfo=None): + super(AccessNode, self).__init__() + + # Properties + self.debuginfo2 = debuginfo + self.access = access + if not isinstance(data, str): + raise TypeError('Data for AccessNode must be a string') + self.data = data + + def __deepcopy__(self, memo): + node = object.__new__(AccessNode) + node._access = self._access + node._data = self._data + node._setzero = self._setzero + node._in_connectors = self._in_connectors + node._out_connectors = self._out_connectors + node.debuginfo2 = dcpy(self.debuginfo2) + return node + + @property + def label(self): + return self.data + + def __label__(self, sdfg, state): + return self.data + + def desc(self, sdfg): + from dace.sdfg import SDFGState, ScopeSubgraphView + if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): + sdfg = sdfg.parent + return sdfg.arrays[self.data] + + def draw_node(self, sdfg, graph): + desc = self.desc(sdfg) + if isinstance(desc, data.Stream): + return dot.draw_node( + sdfg, graph, self, shape="oval", style='dashed') + elif desc.transient: + return dot.draw_node(sdfg, graph, self, shape="oval") + else: + return dot.draw_node(sdfg, graph, self, shape="oval", style='bold') + + def validate(self, sdfg, state): + if self.data not in sdfg.arrays: + raise KeyError('Array "%s" not found in SDFG' % self.data) + + +# ------------------------------------------------------------------------------ + + +class CodeNode(Node): + """ A node that contains runnable code with acyclic external data + dependencies. May either be a tasklet or a nested SDFG, and + denoted by an octagonal shape. """ + pass + + +@make_properties +class Tasklet(CodeNode): + """ A node that contains a tasklet: a functional computation procedure + that can only access external data specified using connectors. + + Tasklets may be implemented in Python, C++, or any supported + language by the code generator. 
+ """ + + label = Property(dtype=str, desc="Name of the tasklet") + language = Property(enum=types.Language, default=types.Language.Python) + code = CodeProperty(desc="Tasklet code") + code_global = CodeProperty( + desc="Global scope code needed for tasklet execution", default="") + code_init = CodeProperty( + desc="Extra code that is called on DaCe runtime initialization", + default="") + code_exit = CodeProperty( + desc="Extra code that is called on DaCe runtime cleanup", default="") + location = Property( + dtype=str, desc="Tasklet execution location descriptor") + debuginfo = DebugInfoProperty() + + def __init__(self, + label, + inputs=set(), + outputs=set(), + code="", + language=types.Language.Python, + code_global="", + code_init="", + code_exit="", + location="-1", + debuginfo=None): + super(Tasklet, self).__init__(inputs, outputs) + + # Properties + self.label = label + self.language = language + self.code = code + self.location = location + self.code_global = code_global + self.code_init = code_init + self.code_exit = code_exit + self.debuginfo = debuginfo + + @property + def name(self): + return self._label + + def draw_node(self, sdfg, graph): + return dot.draw_node(sdfg, graph, self, shape="octagon") + + def validate(self, sdfg, state): + if not data.validate_name(self.label): + raise NameError('Invalid tasklet name "%s"' % self.label) + for in_conn in self.in_connectors: + if not data.validate_name(in_conn): + raise NameError('Invalid input connector "%s"' % in_conn) + for out_conn in self.out_connectors: + if not data.validate_name(out_conn): + raise NameError('Invalid output connector "%s"' % out_conn) + + def __str__(self): + if not self.label: + return "--Empty--" + else: + return self.label + + +class EmptyTasklet(Tasklet): + """ A special tasklet that contains no code. Used for filling empty states + in an SDFG. """ + + def __init__(self, label=""): + super(EmptyTasklet, self).__init__(label) + + def draw_node(self, sdfg, graph): + return dot.draw_node(sdfg, graph, self, style="invis", shape="octagon") + + def validate(self, sdfg, state): + pass + + +# ------------------------------------------------------------------------------ + + +@make_properties +class NestedSDFG(CodeNode): + """ An SDFG state node that contains an SDFG of its own, runnable using + the data dependencies specified using its connectors. + + It is encouraged to use nested SDFGs instead of coarse-grained tasklets + since they are analyzable with respect to transformations. + + @note: A nested SDFG cannot create recursion (one of its parent SDFGs). 
+ """ + + label = Property(dtype=str, desc="Name of the SDFG") + # NOTE: We cannot use SDFG as the type because of an import loop + sdfg = Property(dtype=graph.OrderedDiGraph, desc="The SDFG") + schedule = Property( + dtype=types.ScheduleType, + desc="SDFG schedule", + enum=types.ScheduleType, + from_string=lambda x: types.ScheduleType[x]) + location = Property(dtype=str, desc="SDFG execution location descriptor") + debuginfo = DebugInfoProperty() + is_collapsed = Property( + dtype=bool, + desc="Show this node/scope/state as collapsed", + default=False) + + def __init__(self, + label, + sdfg, + inputs: Set[str], + outputs: Set[str], + schedule=types.ScheduleType.Default, + location="-1", + debuginfo=None): + super(NestedSDFG, self).__init__(inputs, outputs) + + # Properties + self.label = label + self.sdfg = sdfg + self.schedule = schedule + self.location = location + self.debuginfo = debuginfo + + def draw_node(self, sdfg, graph): + return dot.draw_node(sdfg, graph, self, shape="doubleoctagon") + + def __str__(self): + if not self.label: + return "SDFG" + else: + return self.label + + def validate(self, sdfg, state): + if not data.validate_name(self.label): + raise NameError('Invalid nested SDFG name "%s"' % self.label) + for in_conn in self.in_connectors: + if not data.validate_name(in_conn): + raise NameError('Invalid input connector "%s"' % in_conn) + for out_conn in self.out_connectors: + if not data.validate_name(out_conn): + raise NameError('Invalid output connector "%s"' % out_conn) + + # Recursively validate nested SDFG + self.sdfg.validate() + + +# ------------------------------------------------------------------------------ + + +# Scope entry class +class EntryNode(Node): + """ A type of node that opens a scope (e.g., Map or Consume). """ + + def validate(self, sdfg, state): + self.map.validate(sdfg, state, self) + + +# ------------------------------------------------------------------------------ + + +# Scope exit class +class ExitNode(Node): + """ A type of node that closes a scope (e.g., Map or Consume). """ + + def validate(self, sdfg, state): + self.map.validate(sdfg, state, self) + + +# ------------------------------------------------------------------------------ + + +class MapEntry(EntryNode): + """ Node that opens a Map scope. + @see: Map + """ + + def __init__(self, map, dynamic_inputs=set()): + super(MapEntry, self).__init__(dynamic_inputs) + if map is None: + raise ValueError("Map for MapEntry can not be None.") + self._map = map + self._map_depth = 0 + + @property + def map(self): + return self._map + + @map.setter + def map(self, val): + self._map = val + + def draw_node(self, sdfg, graph): + if self.is_collapsed: + return dot.draw_node(sdfg, graph, self, shape="hexagon") + return dot.draw_node(sdfg, graph, self, shape="trapezium") + + def __str__(self): + return str(self.map) + + +class MapExit(ExitNode): + """ Node that closes a Map scope. 
+ @see: Map + """ + + def __init__(self, map): + super(MapExit, self).__init__() + if map is None: + raise ValueError("Map for MapExit can not be None.") + self._map = map + + @property + def map(self): + return self._map + + @map.setter + def map(self, val): + self._map = val + + def draw_node(self, sdfg, graph): + return dot.draw_node(sdfg, graph, self, shape="invtrapezium") + + def __str__(self): + return str(self.map) + + +@make_properties +class Map(object): + """ A Map is a two-node representation of parametric graphs, containing + an integer set by which the contents (nodes dominated by an entry + node and post-dominated by an exit node) are replicated. + + Maps contain a `schedule` property, which specifies how the scope + should be scheduled (execution order). Code generators can use the + schedule property to generate appropriate code, e.g., GPU kernels. + """ + from dace.codegen.instrumentation.perfsettings import PerfSettings + + # List of (editable) properties + label = Property(dtype=str, desc="Label of the map") + params = ParamsProperty(desc="Mapped parameters") + range = RangeProperty(desc="Ranges of map parameters") + # order = OrderProperty(desc="Order of map dimensions", unmapped=True) + schedule = Property( + dtype=types.ScheduleType, + desc="Map schedule", + enum=types.ScheduleType, + from_string=lambda x: types.ScheduleType[x]) + is_async = Property(dtype=bool, desc="Map asynchronous evaluation") + unroll = Property(dtype=bool, desc="Map unrolling") + flatten = Property(dtype=bool, desc="Map loop flattening") + fence_instrumentation = Property( + dtype=bool, desc="Disable instrumentation in all subnodes") + papi_counters = Property( + dtype=list, + desc="List of PAPI counter preset identifiers.", + default=PerfSettings.perf_default_papi_counters()) + debuginfo = DebugInfoProperty() + is_collapsed = Property( + dtype=bool, + desc="Show this node/scope/state as collapsed", + default=False) + + # We cannot have multiple consecutive papi start/stops inside the same thread. The following variable is used to recognize the map that started the counters. + _has_papi_counters = False + _can_be_supersection_start = True # We must have supersections synchronized. + + def __init__(self, + label, + params, + ndrange, + schedule=types.ScheduleType.Default, + unroll=False, + is_async=False, + flatten=False, + fence_instrumentation=False, + debuginfo=None): + super(Map, self).__init__() + + # Assign properties + self.label = label + self.schedule = schedule + self.unroll = unroll + self.is_async = is_async + self.flatten = flatten + self.params = params + self.range = ndrange + self.debuginfo = debuginfo + self._fence_instrumentation = fence_instrumentation + + def __str__(self): + return self.label + "[" + ", ".join([ + "{}={}".format(i, r) + for i, r in zip(self._params, + [sbs.Range.dim_to_string(d) for d in self._range]) + ]) + "]" + + def validate(self, sdfg, state, node): + if not data.validate_name(self.label): + raise NameError('Invalid map name "%s"' % self.label) + + def get_param_num(self): + """ Returns the number of map dimension parameters/symbols. """ + return len(self.params) + + +# Indirect Map properties to MapEntry and MapExit +MapEntry = indirect_properties(Map, lambda obj: obj.map)(MapEntry) +MapExit = indirect_properties(Map, lambda obj: obj.map)(MapExit) + +# ------------------------------------------------------------------------------ + + +class ConsumeEntry(EntryNode): + """ Node that opens a Consume scope. 
+ @see: Consume + """ + + def __init__(self, consume, dynamic_inputs=set()): + super(ConsumeEntry, self).__init__(dynamic_inputs) + if consume is None: + raise ValueError("Consume for ConsumeEntry can not be None.") + self._consume = consume + self.add_in_connector('IN_stream') + self.add_out_connector('OUT_stream') + + @property + def map(self): + return self._consume.as_map() + + @property + def consume(self): + return self._consume + + @consume.setter + def consume(self, val): + self._consume = val + + def draw_node(self, sdfg, graph): + if self.is_collapsed: + return dot.draw_node( + sdfg, graph, self, shape="hexagon", style='dashed') + return dot.draw_node( + sdfg, graph, self, shape="trapezium", style='dashed') + + def __str__(self): + return str(self.consume) + + +class ConsumeExit(ExitNode): + """ Node that closes a Consume scope. + @see: Consume + """ + + def __init__(self, consume): + super(ConsumeExit, self).__init__() + if consume is None: + raise ValueError("Consume for ConsumeExit can not be None.") + self._consume = consume + + @property + def map(self): + return self._consume.as_map() + + @property + def consume(self): + return self._consume + + @consume.setter + def consume(self, val): + self._consume = val + + def draw_node(self, sdfg, graph): + return dot.draw_node( + sdfg, graph, self, shape="invtrapezium", style='dashed') + + def __str__(self): + return str(self.consume) + + +@make_properties +class Consume(object): + """ Consume is a scope, like `Map`, that is a part of the parametric + graph extension of the SDFG. It creates a producer-consumer + relationship between the input stream and the scope subgraph. The + subgraph is scheduled to a given number of processing elements + for processing, and they will try to pop elements from the input + stream until a given quiescence condition is reached. """ + + # Properties + label = Property(dtype=str, desc="Name of the consume node") + pe_index = Property(dtype=str, desc="Processing element identifier") + num_pes = SymbolicProperty(desc="Number of processing elements") + condition = CodeProperty(desc="Quiescence condition", allow_none=True) + language = Property(enum=types.Language, default=types.Language.Python) + schedule = Property( + dtype=types.ScheduleType, + desc="Consume schedule", + enum=types.ScheduleType, + from_string=lambda x: types.ScheduleType[x]) + chunksize = Property( + dtype=int, + desc="Maximal size of elements to consume at a time", + default=1) + debuginfo = DebugInfoProperty() + is_collapsed = Property( + dtype=bool, + desc="Show this node/scope/state as collapsed", + default=False) + + def as_map(self): + """ Compatibility function that allows to view the consume as a map, + mainly in memlet propagation. 
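+            The resulting map iterates ``pe_index`` over ``0:num_pes``.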
""" + return Map(self.label, [self.pe_index], + sbs.Range([(0, self.num_pes - 1, 1)]), self.schedule) + + def __init__(self, + label, + pe_tuple, + condition, + schedule=types.ScheduleType.Default, + chunksize=1, + debuginfo=None): + super(Consume, self).__init__() + + # Properties + self.label = label + self.pe_index, self.num_pes = pe_tuple + self.condition = condition + self.schedule = schedule + self.chunksize = chunksize + self.debuginfo = debuginfo + + def __str__(self): + if self.condition is not None: + return ("%s [%s=0:%s], Condition: %s" % + (self._label, self.pe_index, self.num_pes, + CodeProperty.to_string(self.condition))) + else: + return ( + "%s [%s=0:%s]" % (self._label, self.pe_index, self.num_pes)) + + def validate(self, sdfg, state, node): + if not data.validate_name(self.label): + raise NameError('Invalid consume name "%s"' % self.label) + + def get_param_num(self): + """ Returns the number of consume dimension parameters/symbols. """ + return 1 + + +# Redirect Consume properties to ConsumeEntry and ConsumeExit +ConsumeEntry = indirect_properties(Consume, + lambda obj: obj.consume)(ConsumeEntry) +ConsumeExit = indirect_properties(Consume, + lambda obj: obj.consume)(ConsumeExit) + +# ------------------------------------------------------------------------------ + + +@make_properties +class Reduce(Node): + """ An SDFG node that reduces an N-dimensional array to an + (N-k)-dimensional array, with a list of axes to reduce and + a reduction binary function. """ + from dace.codegen.instrumentation.perfsettings import PerfSettings + + # Properties + axes = Property(dtype=tuple, allow_none=True) + wcr = LambdaProperty() + identity = Property(dtype=object, allow_none=True) + schedule = Property( + dtype=types.ScheduleType, + desc="Reduction execution policy", + enum=types.ScheduleType, + from_string=lambda x: types.ScheduleType[x]) + + papi_counters = Property( + dtype=list, + desc="List of PAPI counter preset identifiers.", + default=PerfSettings.perf_default_papi_counters()) + debuginfo = DebugInfoProperty() + + def __init__(self, + wcr, + axes, + wcr_identity=None, + schedule=types.ScheduleType.Default, + debuginfo=None): + super(Reduce, self).__init__() + self.wcr = wcr # type: ast._Lambda + self.axes = axes + self.identity = wcr_identity + self.schedule = schedule + self.debuginfo = debuginfo + + def draw_node(self, sdfg, state): + return dot.draw_node(sdfg, state, self, shape="invtriangle") + + def __str__(self): + # Autodetect reduction type + redtype = detect_reduction_type(self.wcr) + if redtype == types.ReductionType.Custom: + wcrstr = unparse(ast.parse(self.wcr).body[0].value.body) + else: + wcrstr = str(redtype) + wcrstr = wcrstr[wcrstr.find('.') + 1:] # Skip "ReductionType." + + return 'Op: {op}, Axes: {axes}'.format( + axes=('all' if self.axes is None else str(self.axes)), op=wcrstr) + + def __label__(self, sdfg, state): + # Autodetect reduction type + redtype = detect_reduction_type(self.wcr) + if redtype == types.ReductionType.Custom: + wcrstr = unparse(ast.parse(self.wcr).body[0].value.body) + else: + wcrstr = str(redtype) + wcrstr = wcrstr[wcrstr.find('.') + 1:] # Skip "ReductionType." 
+ + return 'Op: {op}\nAxes: {axes}'.format( + axes=('all' if self.axes is None else str(self.axes)), op=wcrstr) diff --git a/dace/graph/nxutil.py b/dace/graph/nxutil.py new file mode 100644 index 0000000000..feab768075 --- /dev/null +++ b/dace/graph/nxutil.py @@ -0,0 +1,668 @@ +from ast import Subscript +from collections import deque +import copy +import itertools +import re +import os +from typing import Callable, List, Union +from string import ascii_uppercase +import networkx as nx + +import dace +from dace import sdfg, types, symbolic +from dace.config import Config +from dace.graph import nodes, graph as gr + +params = List[dace.symbolic.symbol] +ranges = List[Union[dace.subsets.Range, dace.subsets.Indices]] + + +class CannotExpand(Exception): + pass + + +def node_path_graph(*args): + """ Generates a path graph passing through the input nodes. + + The function generates a graph using as nodes the input arguments. + Subsequently, it creates a path passing through all the nodes, in + the same order as they were given in the function input. + + @param *args: Variable number of nodes or a list of nodes. + @return: A directed graph based on the input arguments. + @rtype: gr.OrderedDiGraph + """ + + # 1. Create new networkx directed graph. + path = gr.OrderedDiGraph() + # 2. Place input nodes in a list. + if len(args) == 1 and isinstance(args[0], list): + # Input is a single list of nodes. + input_nodes = args[0] + else: + # Input is a variable number of nodes. + input_nodes = list(args) + # 3. Add nodes to the graph. + path.add_nodes_from(input_nodes) + # 4. Add path edges to the graph. + for i in range(len(input_nodes) - 1): + path.add_edge(input_nodes[i], input_nodes[i + 1], None) + # 5. Return the graph. + return path + + +def depth_limited_search(source, depth): + """ Return best node and its value using a limited-depth Search (depth- + limited DFS). """ + value = source.evaluate() + if depth == 0: + return source, value + + candidate = source + candidate_value = value + + # Node, depth, children generator + stack = [(source, 0, source.children_iter())] + while stack: + node, cur_depth, children = stack[-1] + try: + child = next(children) + child_val = child.evaluate() + # Check for best candidate + if child_val > candidate_value: + candidate = child + candidate_value = child_val + + if cur_depth < depth - 1: + stack.append((child, cur_depth + 1, child.children_iter())) + except StopIteration: + stack.pop() + + # Return maximal candidate + return candidate, candidate_value + + +def depth_limited_dfs_iter(source, depth): + """ Produce nodes in a Depth-Limited DFS. """ + if depth == 0: + yield source + return + + # Node, depth, children generator + stack = [(source, 0, source.children_iter())] + while stack: + node, cur_depth, children = stack[-1] + try: + child = next(children) + yield child + + if cur_depth < depth - 1: + stack.append((child, cur_depth + 1, child.children_iter())) + except StopIteration: + stack.pop() + + +def dfs_topological_sort(G, sources=None, parent=False, condition=None): + """ Produce nodes in a depth-first topological ordering. + + The function produces nodes in a depth-first topological ordering + (DFS to make sure maps are visited properly), with the condition + that each node visited had all its predecessors visited. Applies + for DAGs only. + + @param G: An input DiGraph (assumed acyclic). + @param sources: (optional) node or list of nodes that + specify starting point(s) for depth-first search and return + edges in the component reachable from source. 
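+        @param condition: (optional) a callable ``condition(parent, child)``;
+               a child node is only yielded if the callable returns True.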
@return: A generator of nodes in a depth-first topological ordering.
+
+        @note: Based on http://www.ics.uci.edu/~eppstein/PADS/DFS.py
+        by D. Eppstein, July 2004.
+
+        @note: If a source is not specified then a source is chosen arbitrarily
+        and repeatedly until all components in the graph are searched.
+
+    """
+    if sources is None:
+        # produce nodes for all components
+        nodes = G
+    else:
+        # produce nodes for components with source
+        try:
+            nodes = iter(sources)
+        except TypeError:
+            nodes = [sources]
+
+    visited = set()
+    for start in nodes:
+        if start in visited:
+            continue
+        yield start
+        visited.add(start)
+        stack = [(start, iter(G.neighbors(start)))]
+        while stack:
+            parent, children = stack[-1]
+            try:
+                child = next(children)
+                if child not in visited:
+                    # Make sure that all predecessors have been visited
+                    skip = False
+                    for pred in G.predecessors(child):
+                        if pred not in visited:
+                            skip = True
+                            break
+                    if skip:
+                        continue
+
+                    visited.add(child)
+                    if condition is None or condition(parent, child):
+                        yield child
+                        stack.append((child, iter(G.neighbors(child))))
+            except StopIteration:
+                stack.pop()
+
+
+def dfs_conditional(G, source, condition, reversed=False):
+    """ Traverse a graph (DFS) only through edges that match a condition. """
+    if isinstance(source, list): nodes = source
+    else: nodes = [source]
+
+    def in_edges_reversed(graph):
+        def _in_edges_reversed(node):
+            for e in graph.in_edges(node):
+                ecpy = copy.copy(e)
+                ecpy.reverse()
+                yield ecpy
+
+        return _in_edges_reversed
+
+    get_children = G.out_edges if not reversed else in_edges_reversed(G)
+
+    visited = set()
+    for start in nodes:
+        if start in visited:
+            continue
+        visited.add(start)
+        stack = [(start, get_children(start).__iter__())]
+        while stack:
+            parent, children = stack[-1]
+            try:
+                e = next(children)
+                if e.dst not in visited:
+                    visited.add(e.dst)
+                    if condition(e.src, e.dst, e.data):
+                        yield e
+                        stack.append((e.dst, get_children(e.dst).__iter__()))
+            except StopIteration:
+                stack.pop()
+
+
+def bfs_conditional(G, source, condition):
+    """ Traverse a graph (BFS) only through edges that match a condition. """
+
+    visited = set([source])
+    queue = deque([(source, G.out_edges(source).__iter__())])
+    while queue:
+        parent, children = queue[0]
+        try:
+            e = next(children)
+            if e.dst not in visited:
+                visited.add(e.dst)
+                if condition(e.src, e.dst, e.data):
+                    yield e
+                    queue.append((e.dst, G.out_edges(e.dst).__iter__()))
+        except StopIteration:
+            queue.popleft()
+
+
+def traverse_sdfg_scope(G, source, yield_edges=True):
+    """ Traverse an SDFG scope (nodes dominated by a ScopeEntry and
+        post-dominated by a ScopeExit).
+        @param G: Input graph (assumed SDFGState).
+        @param source: Source node.
+        @param yield_edges: If True, returned generator yields edges
+                            as well as nodes.
+        @return: A generator that iterates over the scope nodes (or edges).
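+                 Each yielded item also carries the current scope depth
+                 (starting from 1 at the entry node).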
+ """ + + if not isinstance(source, nodes.EntryNode): + raise SyntaxError('Source should be an entry node') + + visited = set() + visited.add(source) + + if yield_edges: + for e in G.out_edges(source): + yield tuple(e) + (1, ) + else: + yield source, 1 + + stack = [(1, source, iter(G.out_edges(source)))] + while stack: + scope, parent, children = stack[-1] + try: + e = next(children) + child = e.dst + if child not in visited: + # Make sure that all predecessors have been visited + skip = False + for pred in G.predecessors(child): + if pred not in visited: + skip = True + break + if skip: + continue + + if yield_edges: + if not (isinstance(child, nodes.ExitNode) and scope == 1): + for e in G.out_edges(child): + yield tuple(e) + (scope, ) + else: + yield child, scope + + visited.add(child) + if isinstance(child, nodes.EntryNode): + stack.append((scope + 1, child, iter(G.out_edges(child)))) + elif isinstance(child, nodes.ExitNode): + if scope > 1: # Don't traverse beyond scope + stack.append((scope - 1, child, iter( + G.out_edges(child)))) + else: + stack.append((scope, child, iter(G.out_edges(child)))) + except StopIteration: + stack.pop() + + +def gen_label(prefix=""): + """ Generates a label as A,B,C,...,Z,AA,AB,... """ + indices = [0] + while True: + label = "".join([ascii_uppercase[i] for i in indices]) + yield prefix + label + indices[0] += 1 + for pos, val in enumerate(indices): + if val == len(ascii_uppercase): + indices[pos] = 0 + if len(indices) == pos + 1: + indices.append(1) + else: + indices[pos + 1] += 1 + + +def indstr(x): + try: + return int(x) + except TypeError: # int() argument must be a string, a bytes-like object or a number, not [X] + return str(x) + + +def range_to_str(ranges, limit_length=50): + """ Converts one or multiple range tuples to a string. """ + + try: + len(ranges[0]) + except TypeError: + ranges = [ranges] + + def convert_index(r): + if len(r) == 3: + if r[2] != 1: + return "{}:{}:{}".format( + symbolic.symstr(r[0]), symbolic.symstr(r[1]), + symbolic.symstr(r[2])) + else: + return "{}:{}".format( + symbolic.symstr(r[0]), symbolic.symstr(r[1])) + else: + raise ValueError("Unsupported range: " + str(r)) + + s = [] + for r in ranges: + s.append(convert_index(r)) + + res = ', '.join(s) + if limit_length is not None: + if not Config.get_bool('renderer', 'fulledges') and \ + len(res) > limit_length: + res = '...' + + return "[" + res + "]" + + +def str_to_range(rangeStr): + """ Converts a range string into a range tuple. """ + if rangeStr[0] != "[" or rangeStr[-1] != "]": + raise ValueError("Invalid range " + rangeStr) + rangeStr = re.sub("[\[\] ]", "", rangeStr) + dimensions = rangeStr.split(",") + ranges = [None] * len(dimensions) + for i, r in enumerate(dimensions): + entries = r.split(":") + numArgs = len(entries) + if numArgs < 2 or numArgs > 3: + raise ValueError( + "Range string should contain one or two separators (received " + + str(r) + ")") + iMin = None + iMax = None + step = None + if entries[0]: + iMin = entries[0] + if entries[1]: + iMax = entries[1] + if numArgs == 3: + if not entries[2]: + raise ValueError("Stride for range cannot be empty") + step = entries[2] + ranges[i] = (iMin, iMax, step) + return ranges + + +def make_list(val): + """ If a scalar or string is passed make it a list, otherwise do nothing. + """ + try: + len(val) + if not isinstance(val, str): + return val + except TypeError: + pass + return [val] + + +def make_2d(ranges): + """ If a 1D list is passed, make it 2D, otherwise do nothing. 
""" + if isinstance(ranges, Subscript): + return [ranges] + firstElem = ranges[0] + try: + if isinstance(firstElem, Subscript): + return ranges + len(firstElem) + if not isinstance(firstElem, str): + return ranges + except TypeError: + pass + return [ranges] + + +def label_of(obj): + """ Fetches the label of an object, or generates one if it doesn't exist. + """ + try: + return obj.label + except AttributeError: + try: + return obj.name + except AttributeError: + try: + return next(type(obj)._nameGen) + except AttributeError: + type(obj)._nameGen = gen_label(type(obj).__name__ + " ") + obj.label = next(type(obj)._nameGen) + return obj.label + + +def fullrange(ndslice, var_size): + """ Returns True iff the ND-slice represents the full array size. """ + for dim, (b, e, s) in zip(var_size, ndslice): + if b != 0 or e != symbolic.pystr_to_symbolic( + types.symbol_name_or_value(dim)) or s != 1: + return False + return True + + +def change_edge_dest( + graph: dace.graph.graph.OrderedDiGraph, + node_a: Union[dace.graph.nodes.Node, + dace.graph.graph.OrderedMultiDiConnectorGraph], + node_b: Union[dace.graph.nodes.Node, + dace.graph.graph.OrderedMultiDiConnectorGraph]): + """ Changes the destination of edges from node A to node B. + + The function finds all edges in the graph that have node A as their + destination. It then creates a new edge for each one found, + using the same source nodes and data, but node B as the destination. + Afterwards, it deletes the edges found and inserts the new ones into + the graph. + + @param graph: The graph upon which the edge transformations will be + applied. + @param node_a: The original destination of the edges. + @param node_b: The new destination of the edges to be transformed. + """ + + # Create new incoming edges to node B, by copying the incoming edges to + # node A and setting their destination to node B. + edges = list(graph.in_edges(node_a)) + for e in edges: + # Delete the incoming edges to node A from the graph. + graph.remove_edge(e) + # Insert the new edges to the graph. + if isinstance(e, gr.MultiConnectorEdge): + # dst_conn = e.dst_conn + # if e.dst_conn is not None: + # # Remove connector from node A. + # node_a.remove_in_connector(e.dst_conn) + # # Insert connector to node B. + # if (not node_b.add_in_connector(dst_conn) and isinstance( + # node_b, (dace.graph.nodes.CodeNode, + # dace.graph.nodes.MapEntry))): + # while not node_b.add_in_connector(dst_conn): + # dst_conn = dst_conn + '_' + # graph.add_edge(e.src, e.src_conn, node_b, dst_conn, e.data) + graph.add_edge(e.src, e.src_conn, node_b, e.dst_conn, e.data) + else: + graph.add_edge(e.src, node_b, e.data) + + +def change_edge_src( + graph: dace.graph.graph.OrderedDiGraph, + node_a: Union[dace.graph.nodes.Node, + dace.graph.graph.OrderedMultiDiConnectorGraph], + node_b: Union[dace.graph.nodes.Node, + dace.graph.graph.OrderedMultiDiConnectorGraph]): + """ Changes the sources of edges from node A to node B. + + The function finds all edges in the graph that have node A as their + source. It then creates a new edge for each one found, using the same + destination nodes and data, but node B as the source. Afterwards, it + deletes the edges + found and inserts the new ones into the graph. + + @param graph: The graph upon which the edge transformations will be + applied. + @param node_a: The original source of the edges to be transformed. + @param node_b: The new source of the edges to be transformed. 
+ """ + + # Create new outgoing edges from node B, by copying the outgoing edges from + # node A and setting their source to node B. + edges = list(graph.out_edges(node_a)) + for e in edges: + # Delete the outgoing edges from node A from the graph. + graph.remove_edge(e) + # Insert the new edges to the graph. + if isinstance(e, gr.MultiConnectorEdge): + # src_conn = e.src_conn + # if e.src_conn is not None: + # # Remove connector from node A. + # node_a.remove_out_connector(e.src_conn) + # # Insert connector to node B. + # if (not node_b.add_out_connector(src_conn) and isinstance( + # node_b, (dace.graph.nodes.CodeNode, + # dace.graph.nodes.MapExit))): + # while not node_b.add_out_connector(src_conn): + # src_conn = src_conn + '_' + # graph.add_edge(node_b, src_conn, e.dst, e.dst_conn, e.data) + graph.add_edge(node_b, e.src_conn, e.dst, e.dst_conn, e.data) + else: + graph.add_edge(node_b, e.dst, e.data) + + +def find_source_nodes(graph): + """ Finds the source nodes of a graph. + + The function finds the source nodes of a graph, i.e. the nodes with + zero in-degree. + + @param graph: The graph whose source nodes are being searched for. + @return: A list of the source nodes found. + """ + return [n for n in graph.nodes() if graph.in_degree(n) == 0] + + +def find_sink_nodes(graph): + """ Finds the sink nodes of a graph. + + The function finds the sink nodes of a graph, i.e. the nodes with zero out-degree. + + @param graph: The graph whose sink nodes are being searched for. + @return: A list of the sink nodes found. + """ + return [n for n in graph.nodes() if graph.out_degree(n) == 0] + + +def replace_subgraph(graph: dace.graph.graph.OrderedDiGraph, + old: dace.graph.graph.OrderedDiGraph, + new: dace.graph.graph.OrderedDiGraph): + """ Replaces a subgraph of a graph with a new one. If replacement is not + possible, it returns False. + + The function replaces the 'old' subgraph of the input graph with the + 'new' subgraph. Both the 'old' and the 'new' subgraphs must have + unique source and sink nodes. Graph edges incoming to the source of + the 'old' subgraph have their destination changed to the source of + the 'new subgraph. Likewise, graph edges outgoing from the sink of + the 'old subgraph have their source changed to the sink of the 'new' + subgraph. + + @param graph: The graph upon which the replacement will be applied. + @param old: The subgraph to be replaced. + @param new: The replacement subgraph. + + @return: True if the replacement succeeded, otherwise False. + """ + + # 1. Find the source node of 'old' subgraph. + # 1.1. Retrieve the source nodes of the 'old' subgraph. + old_source_nodes = find_source_nodes(old) + # 1.2. Verify the existence of a unique source in the 'old' subgraph. + if len(old_source_nodes) != 1: + return False + old_source = old_source_nodes[0] + + # 2. Find the sink node of the 'old' subgraph. + # 2.1. Retrieve the sink nodes of the 'old' subgraph. + old_sink_nodes = find_sink_nodes(old) + # 2.2. Verify the existence of a unique sink in the 'old' subgraph. + if len(old_sink_nodes) != 1: + return False + old_sink = old_sink_nodes[0] + + # 3. Find the source node of 'new' subgraph. + # 3.1. Retrieve the source nodes of the 'new' subgraph. + new_source_nodes = find_source_nodes(new) + # 3.2. Verify the existence of a unique source in the 'new' subgraph. + if len(new_source_nodes) != 1: + return False + new_source = new_source_nodes[0] + + # 4. Find the sink node of the 'new' subgraph. + # 4.1. Retrieve the sink nodes of the 'new' subgraph. 
+    new_sink_nodes = find_sink_nodes(new)
+    # 4.2. Verify the existence of a unique sink in the 'new' subgraph.
+    if len(new_sink_nodes) != 1:
+        return False
+    new_sink = new_sink_nodes[0]
+
+    # 5. Add the 'new' subgraph to the graph.
+    # 5.1. Add the nodes of the 'new' subgraph to the graph.
+    graph.add_nodes_from(new.nodes())
+    # 5.2. Add the edges of the 'new' subgraph to the graph.
+    for e in new.edges():
+        graph.add_edge(*e)
+
+    # 6. Create new incoming edges to the source of the 'new' subgraph.
+    change_edge_dest(graph, old_source, new_source)
+
+    # 7. Create new outgoing edges from the sink of the 'new' subgraph.
+    change_edge_src(graph, old_sink, new_sink)
+
+    # 8. Remove all nodes of the 'old' subgraph from the graph.
+    graph.remove_nodes_from(old.nodes())
+
+    # 9. Subgraph replacement has succeeded. Return True.
+    return True
+
+
+def merge_maps(graph: dace.graph.graph.OrderedMultiDiConnectorGraph,
+               outer_map_entry: dace.graph.nodes.MapEntry,
+               outer_map_exit: dace.graph.nodes.MapExit,
+               inner_map_entry: dace.graph.nodes.MapEntry,
+               inner_map_exit: dace.graph.nodes.MapExit,
+               param_merge: Callable[[params, params],
+                                     params] = lambda p1, p2: p1 + p2,
+               range_merge: Callable[[
+                   ranges, ranges
+               ], ranges] = lambda r1, r2: type(r1)(r1.ranges + r2.ranges)
+               ) -> (dace.graph.nodes.MapEntry, dace.graph.nodes.MapExit):
+    """ Merges two maps (their entries and exits). It is assumed that the
+        operation is valid. """
+
+    outer_map = outer_map_entry.map
+    inner_map = inner_map_entry.map
+
+    # Create merged map by inheriting attributes from outer map and using
+    # the merge functions for parameters and ranges.
+    merged_map = dace.graph.nodes.Map(
+        label='_merged_' + outer_map.label + '_' + inner_map.label,
+        params=param_merge(outer_map.params, inner_map.params),
+        ndrange=range_merge(outer_map.range, inner_map.range),
+        schedule=outer_map.schedule,
+        unroll=outer_map.unroll,
+        is_async=outer_map.is_async,
+        flatten=outer_map.flatten,
+        debuginfo=outer_map.debuginfo)
+
+    merged_entry = dace.graph.nodes.MapEntry(merged_map)
+    merged_entry.in_connectors = outer_map_entry.in_connectors
+    merged_entry.out_connectors = outer_map_entry.out_connectors
+
+    merged_exit = dace.graph.nodes.MapExit(merged_map)
+    merged_exit.in_connectors = outer_map_exit.in_connectors
+    merged_exit.out_connectors = outer_map_exit.out_connectors
+
+    graph.add_nodes_from([merged_entry, merged_exit])
+
+    # Redirect edges entering the inner map. Connectors follow the
+    # IN_<n>/OUT_<n> naming convention: an edge entering a map entry at
+    # IN_<n> leaves it at OUT_<n>.
+    inner_in_edges = graph.out_edges(inner_map_entry)
+    for edge in graph.edges_between(outer_map_entry, inner_map_entry):
+        in_conn_num = edge.dst_conn[3:]
+        out_conn = 'OUT_' + in_conn_num
+        inner_edge = [e for e in inner_in_edges if e.src_conn == out_conn][0]
+        graph.remove_edge(edge)
+        graph.remove_edge(inner_edge)
+        graph.add_edge(merged_entry, edge.src_conn, inner_edge.dst,
+                       inner_edge.dst_conn, inner_edge.data)
+
+    # Redirect edges leaving the inner map.
+    inner_out_edges = graph.in_edges(inner_map_exit)
+    for edge in graph.edges_between(inner_map_exit, outer_map_exit):
+        out_conn_num = edge.src_conn[4:]
+        in_conn = 'IN_' + out_conn_num
+        inner_edge = [e for e in inner_out_edges if e.dst_conn == in_conn][0]
+        graph.remove_edge(edge)
+        graph.remove_edge(inner_edge)
+        graph.add_edge(inner_edge.src, inner_edge.src_conn, merged_exit,
+                       edge.dst_conn, inner_edge.data)
+
+    # Redirect outer edges.
+ change_edge_dest(graph, outer_map_entry, merged_entry) + change_edge_src(graph, outer_map_exit, merged_exit) + + # Clean-up + graph.remove_nodes_from( + [outer_map_entry, outer_map_exit, inner_map_entry, inner_map_exit]) + + return merged_entry, merged_exit diff --git a/dace/memlet.py b/dace/memlet.py new file mode 100644 index 0000000000..3950e465d9 --- /dev/null +++ b/dace/memlet.py @@ -0,0 +1,278 @@ +import ast +from functools import reduce +import operator +import copy as cp + +import dace +from dace import data as dt, subsets, symbolic, types +from dace.frontend.operations import detect_reduction_type +from dace.frontend.python.astutils import unparse +from dace.properties import ( + Property, make_properties, DataProperty, ShapeProperty, SubsetProperty, + SymbolicProperty, TypeClassProperty, DebugInfoProperty, LambdaProperty) + + +@make_properties +class Memlet(object): + """ Data movement object. Represents the data, the subset moved, and the + manner it is reindexed (`other_subset`) into the destination. + If there are multiple conflicting writes, this object also specifies + how they are resolved with a lambda function. + """ + + # Properties + veclen = Property(dtype=int, desc="Vector length") + num_accesses = SymbolicProperty() + subset = SubsetProperty() + other_subset = SubsetProperty(allow_none=True) + data = DataProperty() + debuginfo = DebugInfoProperty() + wcr = LambdaProperty(allow_none=True) + wcr_identity = Property(dtype=object, default=None, allow_none=True) + wcr_conflict = Property(dtype=bool, default=True) + + def __init__(self, + data, + num_accesses, + subset, + vector_length, + wcr=None, + wcr_identity=None, + other_subset=None, + debuginfo=None, + wcr_conflict=True): + """ Constructs a Memlet. + @param data: The data object or name to access. B{Note:} this + parameter will soon be deprecated. + @type data: Either a string of the data descriptor name or an + AccessNode. + @param num_accesses: The number of times that the moved data + will be subsequently accessed. If + `dace.types.DYNAMIC` (-1), + designates that the number of accesses is + unknown at compile time. + @param subset: The subset of `data` that is going to be accessed. + @param vector_length: The length of a single unit of access to + the data (used for vectorization + optimizations). + @param wcr: A lambda function specifying how write-conflicts + are resolved. The syntax of the lambda function receives two elements: `current` value and `new` value, + and returns the value after resolution. For example, + summation is `lambda cur, new: cur + new`. + @param wcr_identity: Identity value used for the first write + conflict. B{Note:} this parameter will soon + be deprecated. + @param other_subset: The reindexing of `subset` on the other + connected data. + @param debuginfo: Source-code information (e.g., line, file) + used for debugging. + @param wcr_conflict: If False, forces non-locked conflict + resolution when generating code. The default + is to let the code generator infer this + information from the SDFG. 
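+
+            Example (illustrative sketch; assumes an array "A" with at least
+            10 elements is registered in the SDFG):
+
+                # Move 10 elements of A, resolving write conflicts by
+                # summation
+                Memlet('A', 10, subsets.Range([(0, 9, 1)]), 1,
+                       wcr='lambda cur, new: cur + new')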
+ """ + + # Properties + self.num_accesses = num_accesses # type: sympy math + self.subset = subset # type: subsets.Subset + self.veclen = vector_length # type: int (in elements, default 1) + if hasattr(data, 'data'): + data = data.data + self.data = data # type: str + + # Annotates memlet with _how_ writing is performed in case of conflict + self.wcr = wcr + self.wcr_identity = wcr_identity + self.wcr_conflict = wcr_conflict + + # The subset of the other endpoint we are copying from/to (note: + # carries the dimensionality of the other endpoint too!) + self.other_subset = other_subset + + self.debuginfo = debuginfo + + def toJSON(self, indent=0): + json = " " * indent + "{\n" + indent += 2 + json += " " * indent + "\"type\" : \"" + type(self).__name__ + "\",\n" + json += " " * indent + "\"label\" : \"" + str(self) + "\"\n" + indent -= 2 + json += " " * indent + "}\n" + return json + + @staticmethod + def simple(data, + subset_str, + veclen=1, + wcr_str=None, + wcr_identity=None, + other_subset_str=None, + wcr_conflict=True, + num_accesses=None, + debuginfo=None): + """ Constructs a Memlet from string-based expressions. + @param data: The data object or name to access. B{Note:} this + parameter will soon be deprecated. + @type data: Either a string of the data descriptor name or an + AccessNode. + @param subset_str: The subset of `data` that is going to + be accessed in string format. Example: '0:N'. + @param veclen: The length of a single unit of access to + the data (used for vectorization optimizations). + @param wcr_str: A lambda function (as a string) specifying + how write-conflicts are resolved. The syntax + of the lambda function receives two elements: + `current` value and `new` value, + and returns the value after resolution. For + example, summation is + `'lambda cur, new: cur + new'`. + @param wcr_identity: Identity value used for the first write + conflict. B{Note:} this parameter will soon + be deprecated. + @param other_subset_str: The reindexing of `subset` on the other + connected data (as a string). + @param wcr_conflict: If False, forces non-locked conflict + resolution when generating code. The default + is to let the code generator infer this + information from the SDFG. + @param num_accesses: The number of times that the moved data + will be subsequently accessed. If + `dace.types.DYNAMIC` (-1), + designates that the number of accesses is + unknown at compile time. + @param debuginfo: Source-code information (e.g., line, file) + used for debugging. + + """ + subset = SubsetProperty.from_string(subset_str) + if num_accesses is not None: + na = num_accesses + else: + na = subset.num_elements() + + if wcr_str is not None: + wcr = LambdaProperty.from_string(wcr_str) + else: + wcr = None + + if other_subset_str is not None: + other_subset = SubsetProperty.from_string(other_subset_str) + else: + other_subset = None + + # If it is an access node or another memlet + if hasattr(data, 'data'): + data = data.data + + return Memlet( + data, + na, + subset, + veclen, + wcr=wcr, + wcr_identity=wcr_identity, + other_subset=other_subset, + wcr_conflict=wcr_conflict, + debuginfo=debuginfo) + + @staticmethod + def from_array(dataname, datadesc): + """ Constructs a Memlet that transfers an entire array's contents. + @param dataname: The name of the data descriptor in the SDFG. + @param datadesc: The data descriptor object. + @type datadesc: Data. 
+ """ + range = subsets.Range.from_array(datadesc) + return Memlet(dataname, range.num_elements(), range, 1) + + def __hash__(self): + return hash((self.data, self.num_accesses, self.subset, self.veclen, + str(self.wcr), self.wcr_identity, self.other_subset)) + + def __eq__(self, other): + return all([ + self.data == other.data, self.num_accesses == other.num_accesses, + self.subset == other.subset, self.veclen == other.veclen, + self.wcr == other.wcr, self.wcr_identity == other.wcr_identity, + self.other_subset == other.other_subset + ]) + + def num_elements(self): + """ Returns the number of elements in the Memlet subset. """ + return self.subset.num_elements() + + def bounding_box_size(self): + """ Returns a per-dimension upper bound on the maximum number of + elements in each dimension. + + This bound will be tight in the case of Range. + """ + return self.subset.bounding_box_size() + + def validate(self, sdfg, state): + if self.data not in sdfg.arrays: + raise KeyError('Array "%s" not found in SDFG' % self.data) + + def __label__(self, sdfg, state): + """ Returns a string representation of the memlet for display in a + graph. + + @param sdfg: The SDFG in which the memlet resides. + @param state: An SDFGState object in which the memlet resides. + """ + if self.data is None: + return self._label(None) + return self._label(sdfg.arrays[self.data].shape) + + def __str__(self): + return self._label(None) + + def _label(self, shape): + result = '' + if self.data is not None: + result = self.data + + if self.subset is None: + return result + + num_elements = self.subset.num_elements() + if self.num_accesses != num_elements: + result += '(%s) ' % str(self.num_accesses) + arrayNotation = True + try: + if shape is not None and reduce(operator.mul, shape, 1) == 1: + # Don't draw array if we're accessing a single element + arrayNotation = False + except TypeError: + # Will fail if trying to check the truth value of a sympy expr + pass + if arrayNotation: + result += '[%s]' % str(self.subset) + if self.wcr is not None and str(self.wcr) != '': + # Autodetect reduction type + redtype = detect_reduction_type(self.wcr) + if redtype == types.ReductionType.Custom: + wcrstr = unparse(ast.parse(self.wcr).body[0].value.body) + else: + wcrstr = str(redtype) + wcrstr = wcrstr[wcrstr.find('.') + 1:] # Skip "ReductionType." + + result += ' (CR: %s' % wcrstr + if self.wcr_identity is not None: + result += ', id: %s' % str(self.wcr_identity) + result += ')' + + if self.other_subset is not None: + result += ' -> [%s]' % str(self.other_subset) + return result + + def __repr__(self): + return "Memlet (" + self.__str__() + ")" + + +class EmptyMemlet(Memlet): + """ A memlet without data. Primarily used for connecting nodes to scopes + without transferring data to them. 
""" + + def __init__(self): + super(EmptyMemlet, self).__init__(None, 0, None, 1) diff --git a/dace/properties.py b/dace/properties.py new file mode 100644 index 0000000000..9dd11432fe --- /dev/null +++ b/dace/properties.py @@ -0,0 +1,846 @@ +import ast +import astunparse +from collections import OrderedDict +import copy +from dace.frontend.python.astutils import unparse +import itertools +import pydoc +import re +import sympy as sp +import numpy as np +import dace.subsets as sbs +import dace +from dace.symbolic import pystr_to_symbolic +from dace.types import DebugInfo + +############################################################################### +# External interface to guarantee correct usage +############################################################################### + + +def set_property_from_string(prop, obj, string, sdfg=None): + """ Interface function that guarantees that a property will always be + correctly set, if possible, by accepting all possible input arguments to + from_string. """ + + # If the property is a string (property name), obtain it from the object + if isinstance(prop, str): + prop = type(obj).__properties__[prop] + + if isinstance(prop, CodeProperty): + val = prop.from_string(string, obj.language) + elif isinstance(prop, (ReferenceProperty, DataProperty)): + if sdfg is None: + raise ValueError( + "You cannot pass sdfg=None when editing a ReferenceProperty!") + val = prop.from_string(string, sdfg) + else: + val = prop.from_string(string) + setattr(obj, prop.attr_name, val) + + +############################################################################### +# Property base implementation +############################################################################### + + +class PropertyError(Exception): + """Exception type for errors related to internal functionality of + these properties.""" + pass + + +class Property: + """ Class implementing properties of DaCe objects that conform to strong + typing, and allow conversion to and from strings to be edited. 
""" + + def __init__( + self, + getter=None, + setter=None, + dtype=None, + default=None, + from_string=None, + to_string=None, + enum=None, # Values must be present in this enum + unmapped=False, # Don't enforce 1:1 mapping with a member variable + allow_none=False, + indirected=False, # This property belongs to a different class + desc=""): + + self._getter = getter + self._setter = setter + self._dtype = dtype + self._default = default + if from_string is not None: + self._from_string = from_string + elif enum is not None: + self._from_string = lambda s: enum[s] + else: + self._from_string = self.dtype + if to_string is not None: + self._to_string = to_string + elif enum is not None: + self._to_string = lambda val: val._name_ + else: + self._to_string = str + self._enum = enum + self._unmapped = unmapped + self._allow_none = allow_none + self._indirected = indirected + self._desc = desc + + def __get__(self, obj, objtype=None): + if obj is None: + # Called on the class rather than an instance, so return the + # property object itself + return self + # If a custom getter is specified, use it + if self.getter: + return self.getter(obj) + if not hasattr(self, "attr_name"): + raise RuntimeError("Attribute name not set") + # Otherwise look for attribute prefixed by "_" + return getattr(obj, "_" + self.attr_name) + + def __set__(self, obj, val): + # If custom setter is specified, use it + if self.setter: + return self.setter(obj, val) + if not hasattr(self, "attr_name"): + raise RuntimeError("Attribute name not set") + # Fail on None unless explicitly allowed + if val is None and not self.allow_none: + raise ValueError( + "None not allowed for property {} in class {}".format( + self.attr_name, + type(obj).__name__)) + + # Accept all DaCe/numpy typeclasses as Python native types + if isinstance(val, np.number): + val = val.item() + + # Check if type matches before setting + if (self.dtype is not None and not isinstance(val, self.dtype) + and not (val is None and self.allow_none)): + if isinstance(val, str): + raise TypeError( + "Received str for property {} of type {}. Use " + "dace.properties.set_property_from_string or the " + "from_string method of the property.".format( + self.attr_name, self.dtype)) + raise TypeError( + "Invalid type \"{}\" for property {}: expected {}".format( + type(val).__name__, self.attr_name, self.dtype.__name__)) + # If the value has not yet been set, we cannot pass it to the enum + # function. 
Fail silently if this happens + if self.enum is not None and isinstance(self.enum, (list, tuple, set)): + if val not in self.enum: + raise ValueError("Value {} not present in enum: {}".format( + val, self.enum)) + setattr(obj, "_" + self.attr_name, val) + + # Property-ception >:-) + + @property + def getter(self): + return self._getter + + @getter.setter + def getter(self, val): + self._getter = val + + @property + def setter(self): + return self._setter + + @setter.setter + def setter(self, val): + self._setter = val + + @property + def dtype(self): + return self._dtype + + @property + def default(self): + return self._default + + @property + def allow_none(self): + return self._allow_none + + @property + def desc(self): + return self._desc + + @property + def from_string(self): + return self._from_string + + @property + def to_string(self): + return self._to_string + + @property + def enum(self): + return self._enum + + @property + def unmapped(self): + return self._unmapped + + @property + def indirected(self): + return self._indirected + + @indirected.setter + def indirected(self, val): + self._indirected = val + + +############################################################################### +# Decorator for objects with properties +############################################################################### + + +def _property_generator(instance): + for name, prop in type(instance).__properties__.items(): + yield prop, getattr(instance, name) + + +def make_properties(cls): + """ A decorator for objects that adds support and checks for strongly-typed + properties (which use the Property class). + """ + + # Extract all Property members of the class + properties = OrderedDict([(name, prop) + for name, prop in cls.__dict__.items() + if isinstance(prop, Property)]) + # Set the property name to its field name in the class + for name, prop in properties.items(): + prop.attr_name = name + # Grab properties from baseclass(es) + own_properties = copy.copy(properties) + for base in cls.__bases__: + if hasattr(base, "__properties__"): + duplicates = base.__properties__.keys() & own_properties.keys() + if len(duplicates) != 0: + raise AttributeError( + "Duplicate properties in class {} deriving from {}: {}". + format(cls.__name__, base.__name__, duplicates)) + properties.update(base.__properties__) + # Add the list of properties to the class + cls.__properties__ = properties + # Add an iterator to pairs of property names and their values + cls.properties = _property_generator + + # Grab old init. This will be brought into the closure in the below + init = cls.__init__ + + def initialize_properties(obj, *args, **kwargs): + # Set default values. 
If we don't do this, properties that depend on + # other might fail because the others rely on being set by a default + # value + for name, prop in own_properties.items(): + # Only assign our own properties, so we don't overwrite what's been + # set by the base class + if hasattr(obj, name): + raise PropertyError( + "Property {} already assigned in {}".format( + name, + type(obj).__name__)) + if not prop.indirected and prop.default is not None: + setattr(obj, name, prop.default) + # Now call vanilla __init__, which can initialize members + init(obj, *args, **kwargs) + # Assert that all properties have been set + for name, prop in properties.items(): + try: + getattr(obj, name) + except AttributeError: + if not prop.unmapped: + raise PropertyError( + "Property {} is unassigned in __init__ for {}".format( + name, cls.__name__)) + # Assert that there are no fields in the object not captured by + # properties, unless they are prefixed with "_" + for name, prop in obj.__dict__.items(): + if name not in properties and not name.startswith("_"): + raise PropertyError( + "{} : Variable {} is neither a Property nor " + "an internal variable (prefixed with \"_\")".format( + str(type(obj)), name)) + + # Replace the __init__ method + cls.__init__ = initialize_properties + + return cls + + +def indirect_property(cls, f, prop, override): + + # Make a copy of the original property, but override its getter and setter + prop_name = prop.attr_name + prop_indirect = copy.copy(prop) + prop_indirect.indirected = True + + # Because this is a separate function, prop_name is caught in the closure + def indirect_getter(obj): + return getattr(f(obj), prop_name) + + def indirect_setter(obj, val): + return setattr(f(obj), prop_name, val) + + prop_indirect.getter = indirect_getter + prop_indirect.setter = indirect_setter + + # Add the property to the class + if not override and hasattr(cls, prop_name): + raise TypeError( + "Property \"{}\" already exists in class \"{}\"".format( + prop_name, cls.__name__)) + setattr(cls, prop_name, prop_indirect) + + +def indirect_properties(indirect_class, indirect_function, override=False): + """ A decorator for objects that provides indirect properties defined + in another class. + """ + + def indirection(cls): + # For every property in the class we are indirecting to, create an + # indirection property in this class + for prop in indirect_class.__properties__.values(): + indirect_property(cls, indirect_function, prop, override) + return make_properties(cls) + + return indirection + + +############################################################################### +# Custom properties +############################################################################### + + +# TODO: does not currently work because of how enums work +class OrderProperty(Property): + """ Custom property class that handles the mapping between the order + property and the actual class fields (range and parameters). """ + + # This is implemented in the context of dace.nodes.Map, but could in + # principle be reused for other objects, assuming they set the internal + # fields "_range" and "_params". + + def __get__(self, obj, objtype=None): + # Copy to avoid changes in the list at callee to be reflected in + # the map directly + return list(obj._params) + + def __set__(self, obj, val): + """ Update both params and ranges based on the new order. 
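+
+            Illustrative sketch of the intended behavior (note the TODO
+            above; assumes `m` declares `order = OrderProperty()` and is a
+            Map with parameters [i, j]):
+
+                m.order = ['j', 'i']  # Also reorders m._range accordingly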
""" + # Make this more lenient to the input by comparing strings, and + # using the new order to shuffle the original lists + param_strings = list(map(str, obj._params)) + update_strings = list(map(str, val)) + if len(update_strings) != len(param_strings): + raise ValueError( + "Wrong length of new order: {} (found {}, expected {})".format( + str(val), len(update_strings), len(param_strings))) + # The below will throw a ValueError if a parameter doesn't exist + # We assume that no parameter will be present twice... + indices = [param_strings.index(x) for x in update_strings] + obj._params = [obj._params[i] for i in indices] + obj._range.reorder(indices) + + @staticmethod + def to_string(val): + return "({})".format(", ".join(map(str, val))) + + @staticmethod + def from_string(s): + """Create a list of symbols from a list of strings.""" + return [sp.Symbol(i) for i in re.sub("[\(\)\[\]]", "", s).split(",")] + + @staticmethod + def enum(obj): + """Implement enum to populate e.g. dropdown.""" + return list(itertools.permutations(obj)) + + +class RangeProperty(Property): + """ Custom Property type for `dace.graph.subset.Range` members. """ + + def __set__(self, obj, value): + if isinstance(value, list): + value = dace.subsets.Range(value) + super(RangeProperty, self).__set__(obj, value) + + @property + def dtype(self): + return sbs.Range + + @staticmethod + def to_string(obj): + return sbs.Range.ndslice_to_string(obj) + + @staticmethod + def from_string(s): + return sbs.Range.from_string(s) + + +class DebugInfoProperty(Property): + """ Custom Property type for DebugInfo members. """ + + @property + def dtype(self): + return DebugInfo + + @property + def allow_none(self): + return True + + @staticmethod + def to_string(di): + if isinstance(di, DebugInfo): + r = "file:" + str(di.filename) + " " + r += "from line: " + str(di.start_line) + " col: " + str( + di.start_column) + " " + r += "to line: " + str(di.end_line) + " col: " + str(di.end_column) + return r + else: + return "None" + + @staticmethod + def from_string(s): + f = None + sl = 0 + el = 0 + sc = 0 + ec = 0 + info_available = False + di = None + + m = re.search("file: (\w+)", s) + if m is not None: + info_available = True + f = sl = m.group(1) + m = re.search("from line: (\d+)", s) + if m is not None: + sl = m.group(1) + el = sl + info_available = True + m = re.search("to line: (\d+)", s) + if m is not None: + el = m.group(1) + info_available = True + m = re.search("from col: (\d+)", s) + if m is not None: + sc = m.group(1) + ec = sc + info_available = True + m = re.search("to col: (\d+)", s) + if m is not None: + ec = m.group(1) + info_available = True + if info_available: + di = DebugInfo(f, sl, sc, el, ec) + return di + + +class ParamsProperty(Property): + """ Property for list of parameters, such as parameters for a Map. """ + + @property + def dtype(self): + return list + + @staticmethod + def to_string(l): + return "[{}]".format(", ".join(map(str, l))) + + @staticmethod + def from_string(s): + return [ + sp.Symbol(m.group(0)) + for m in re.finditer("[a-zA-Z_][a-zA-Z0-9_]*", s) + ] + + +class SetProperty(Property): + """Property for a set of elements of one type, e.g., connectors. 
""" + + def __init__( + self, + element_type, + getter=None, + setter=None, + default=None, + from_string=None, + to_string=None, + unmapped=False, # Don't enforce 1:1 mapping with a member variable + allow_none=False, + desc=""): + super(SetProperty, self).__init__( + getter=getter, + setter=setter, + dtype=set, + default=default, + from_string=from_string, + to_string=to_string, + enum=None, + unmapped=unmapped, + allow_none=allow_none, + desc=desc) + self._element_type = element_type + + @property + def dtype(self): + return set + + @staticmethod + def to_string(l): + return str(l) + + @staticmethod + def from_string(s): + return [eval(i) for i in re.sub("[\{\}\(\)\[\]]", "", s).split(",")] + + def __get__(self, obj, objtype=None): + # Copy to avoid changes in the set at callee to be reflected in + # the node directly + return set(super(SetProperty, self).__get__(obj, objtype)) + + def __set__(self, obj, val): + # Check for uniqueness + if len(val) != len(set(val)): + dups = set([x for x in val if val.count(x) > 1]) + raise ValueError('Duplicates found in set: ' + str(dups)) + # Cast to element type + try: + new_set = set(self._element_type(elem) for elem in val) + except (TypeError, ValueError): + raise ValueError('Some elements could not be converted to %s' % + (str(self._element_type))) + + super(SetProperty, self).__set__(obj, new_set) + + +class LambdaProperty(Property): + """ Custom Property type that accepts a lambda function, with conversions + to and from strings. """ + + @property + def dtype(self): + return str + + @staticmethod + def from_string(s): + return ast.parse(s).body[0].value + + @staticmethod + def to_string(obj): + if obj is None: + return 'lambda: None' + if isinstance(obj, str): + return obj + return unparse(obj) + + def __set__(self, obj, val): + if val is not None: + if isinstance(val, str): + self.from_string(val) # Check that from_string doesn't fail + elif isinstance(val, ast.Lambda): + val = self.to_string(val) # Store as string internally + else: + raise TypeError( + "Lambda property must be either string or ast.Lambda") + super(LambdaProperty, self).__set__(obj, val) + + +class CodeBlock(list): + """ Helper class that represents AST code blocks for `CodeProperty`, + implemented as a list with an extra _as_string property. The object + also stores the original string, allowing us to preserve comments and + formatting from user input. + """ + + def __init__(self, *args, **kwargs): + self._as_string = "" + super().__init__(*args, **kwargs) + + @property + def as_string(self): + return self._as_string + + @as_string.setter + def as_string(self, string): + self._as_string = string + + +class CodeProperty(Property): + """ Custom Property type that accepts code in various languages. """ + + @property + def dtype(self): + return None + + @staticmethod + def from_string(string, language=None): + if language is None: + raise TypeError("Must pass language as second argument to " + "from_string method of CodeProperty") + if language == dace.types.Language.Python: + block = CodeBlock(ast.parse(string).body) + block.as_string = string + return block + else: + # Do nothing for now + return string + + @staticmethod + def to_string(obj): + if isinstance(obj, str): + return obj + # Grab the originally parsed string if any + if obj._as_string is not None and obj._as_string != "": + return obj._as_string + # It's probably good enough to assume that there is an original string + # if the language was not Python, so we just throw the string to the + # astunparser. 
+        return unparse(obj)
+
+    def __set__(self, obj, val):
+        # Check if the class has a language property
+        if not hasattr(type(obj), "language"):
+            raise AttributeError(
+                "Class \"{}\" with a CodeProperty field must also "
+                "have a \"language\" attribute.".format(type(obj).__name__))
+        # Check if the object has a language attribute
+        try:
+            language = obj.language
+        except AttributeError:
+            # Language exists as an attribute, but has not yet been set. Accept
+            # this, because __dict__ is not guaranteed to be in the order that
+            # the attributes are defined in.
+            language = None
+        if val is None:
+            # Keep as None. The "allow_none" check in the superclass
+            # ensures that this is legal
+            pass
+        elif isinstance(val, str):
+            if language is not None:
+                # Store original string
+                val = self.from_string(val, language)
+        else:
+            try:
+                if language is not dace.types.Language.Python:
+                    raise TypeError("Only strings accepted for other "
+                                    "languages than Python.")
+            except AttributeError:
+                # Don't check language if it has not been set yet. We will
+                # assume it's Python AST, since it wasn't a string
+                pass
+            if isinstance(val, (ast.FunctionDef, ast.With)):
+                # TODO: the original parsing should have already stripped this
+                val = CodeBlock(val.body)
+            elif isinstance(val, ast.AST):
+                val = CodeBlock([val])
+            else:
+                try:
+                    iter(val)
+                except TypeError:
+                    raise TypeError(
+                        "CodeProperty expected an iterable of expressions, "
+                        "got {}".format(type(val).__name__))
+                for e in val:
+                    if not isinstance(e, ast.AST):
+                        raise TypeError(
+                            "Found type {} in list of AST expressions: "
+                            "expected ast.AST".format(type(e).__name__))
+        super(CodeProperty, self).__set__(obj, val)
+
+
+class SubsetProperty(Property):
+    """ Custom Property type that accepts any form of subset, and enables
+        parsing strings into multiple types of subsets. """
+
+    @property
+    def dtype(self):
+        return None
+
+    @property
+    def allow_none(self):
+        return True
+
+    def __set__(self, obj, val):
+        if (val is not None and not isinstance(val, sbs.Range)
+                and not isinstance(val, sbs.Indices)):
+            try:
+                val = self.from_string(val)
+            except SyntaxError:
+                raise TypeError(
+                    "Subset property must be either Range or Indices: got {}".
+                    format(type(val).__name__))
+        super(SubsetProperty, self).__set__(obj, val)
+
+    @staticmethod
+    def from_string(s):
+        if s is None or s == 'None' or len(s) == 0:
+            return None
+        ranges = sbs.Range.from_string(s)
+        if ranges:
+            return ranges
+        else:
+            return sbs.Indices.from_string(s)
+
+    @staticmethod
+    def to_string(val):
+        if isinstance(val, sbs.Range):
+            return sbs.Range.ndslice_to_string(val)
+        elif isinstance(val, sbs.Indices):
+            return sbs.Indices.__str__(val)
+        elif val is None:
+            return 'None'
+        raise TypeError
+
+
+class SymbolicProperty(Property):
+    """ Custom Property type that accepts integers or Sympy expressions. """
+
+    @property
+    def dtype(self):
+        return None
+
+    def __set__(self, obj, val):
+        if (not isinstance(val, sp.expr.Expr) and not isinstance(val, int)
+                and not isinstance(val, str)):
+            raise TypeError(
+                "Property {} must be an int or symbolic expression".format(
+                    self.attr_name))
+        super(SymbolicProperty, self).__set__(obj, val)
+
+    @staticmethod
+    def from_string(s):
+        return pystr_to_symbolic(s)
+
+
+class DataProperty(Property):
+    """ Custom Property type that represents a link to a data descriptor.
+        Needs the SDFG to be passed as an argument to `from_string` and
+        `enum`.
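+
+        Example (illustrative sketch; assumes `sdfg` defines an array "A"):
+
+            name = DataProperty.from_string('A', sdfg)  # Validates the name
+            options = DataProperty.enum(sdfg)           # ['A', ...]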
""" + + def __init__(self, desc='', default=None): + # Data can be None when no data is flowing, e.g., on a memlet with a + # map that has no external inputs + return super().__init__( + dtype=str, allow_none=True, desc=desc, default=default) + + @staticmethod + def enum(sdfg=None): + if sdfg is None: + raise TypeError("Must pass SDFG as second argument to " + "enum method of ArrayProperty") + return list(sdfg.arrays.keys()) + + @staticmethod + def from_string(s, sdfg=None): + if sdfg is None: + raise TypeError("Must pass SDFG as second argument to " + "from_string method of ArrayProperty") + if s not in sdfg.arrays: + raise ValueError("No data found in SDFG with name: {}".format(s)) + return s + + @staticmethod + def to_string(obj): + return str(obj) + + +class ReferenceProperty(Property): + """ Custom Property type that represents a link to another SDFG object. + Needs the SDFG to be passed as an argument to `from_string`.""" + + @staticmethod + def from_string(s, sdfg=None): + if sdfg is None: + raise TypeError("Must pass SDFG as second argument to " + "from_string method of ReferenceProperty") + for node in sdfg.states(): + if node.label == s: + return node + for node, _ in sdfg.all_nodes_recursive(): + if node.label == s: + return node + raise ValueError("No node found in SDFG with name: {}".format(s)) + + @staticmethod + def to_string(obj): + return obj.label + + +class ShapeProperty(Property): + """ Custom Property type that defines a shape. """ + + @property + def dtype(self): + return tuple + + @staticmethod + def from_string(s): + if s[0] == "(" and s[-1] == ")": + s = s[1:-1] + return tuple([ + dace.symbolic.pystr_to_symbolic(m.group(0)) + for m in re.finditer("[^,;:]+", s) + ]) + + @staticmethod + def to_string(obj): + return ", ".join(map(str, obj)) + + def __set__(self, obj, val): + if isinstance(val, list): + val = tuple(val) + super(ShapeProperty, self).__set__(obj, val) + + +class TypeProperty(Property): + """ Custom Property type that finds a type according to the input string. + """ + + @property + def dtype(self): + return type + + # TODO: this does not work both ways! If converted to a string we lose the + # location information. + @staticmethod + def from_string(s): + dtype = pydoc.locate(s) + if dtype is None: + raise ValueError("No type \"{}\" found.".format(s)) + if not isinstance(dtype, type): + raise ValueError("Object \"{}\" is not a type.".format(dtype)) + return dtype + + +class TypeClassProperty(Property): + """ Custom property type for memory as defined in dace.types, + e.g. `dace.float32`. 
""" + + @property + def dtype(self): + return dace.types.typeclass + + @staticmethod + def from_string(s): + dtype = pydoc.locate("dace.types.{}".format(s)) + if dtype is None or not isinstance(dtype, dace.types.typeclass): + raise ValueError("Not a valid data type: {}".format(s)) + return dtype + + @staticmethod + def to_string(obj): + return obj.to_string() diff --git a/dace/runtime/include/dace/complex.h b/dace/runtime/include/dace/complex.h new file mode 100644 index 0000000000..f220c347b6 --- /dev/null +++ b/dace/runtime/include/dace/complex.h @@ -0,0 +1,63 @@ +#ifndef __DACE_COMPLEX_H +#define __DACE_COMPLEX_H + +#include + +#ifdef __CUDACC__ + #include + #define dace_conj thrust::conj +#else + #define dace_conj std::conj +#endif + +// Contains a complex-j class to support the native complex type in Python + +namespace dace +{ + struct complexJ + { + int val; + explicit complexJ(int v = 1) : val(v) {} + }; + + static inline int operator*(const complexJ& j1, const complexJ& j2) + { + return -j1.val * j2.val; + } + template + std::complex operator*(const complexJ& j, const T& other) + { + return std::complex(T(0), j.val * other); + } + template + std::complex operator*(const T& other, const complexJ& j) + { + return std::complex(T(0), j.val * other); + } + static inline complexJ operator*(const int& other, const complexJ& j) + { + return complexJ(j.val * other); + } + static inline complexJ operator*(const complexJ& j, const int& other) + { + return complexJ(j.val * other); + } + static inline complexJ operator-(const complexJ& j) + { + return complexJ(-j.val); + } +} + + +// Complex-scalar multiplication functions + +template +std::complex operator*(const std::complex& a, const int& b) { + return std::complex(b*a.real(), b*a.imag()); +} +template +std::complex operator*(const int& a, const std::complex& b) { + return std::complex(a*b.real(), a*b.imag()); +} + +#endif // __DACE_COMPLEX_H diff --git a/dace/runtime/include/dace/copy.h b/dace/runtime/include/dace/copy.h new file mode 100644 index 0000000000..6ae7016361 --- /dev/null +++ b/dace/runtime/include/dace/copy.h @@ -0,0 +1,267 @@ +#ifndef __DACE_COPY_H +#define __DACE_COPY_H + +#include "types.h" +#include "vector.h" + +namespace dace +{ + template + inline void InitArray(T *ptr, const U& value, int size) + { + for (int i = 0; i < size; ++i) + *ptr++ = T(value); + } + + template + struct CopyND; + template + struct CopyNDDynamic; + + template + struct CopyND + { + template + struct ConstSrc + { + template + static DACE_HDFI void Copy(const T *src, T *dst, const int& dst_stride, const Args&... dst_otherdims) + { +#ifndef __CUDA_ARCH__ + // Memcpy specialization + if (sizeof...(OTHER_COPYDIMS) == 0 && SRC_STRIDE == 1 && dst_stride == 1) { + memcpy(dst, src, COPYDIM * sizeof(T) * VECLEN); + return; + } +#endif + + __DACE_UNROLL + for (int i = 0; i < COPYDIM; ++i) + CopyND::template ConstSrc::Copy( + src + i * SRC_STRIDE, dst + i * dst_stride, dst_otherdims...); + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc, const int& dst_stride, const Args&... dst_otherdims) + { + __DACE_UNROLL + for (int i = 0; i < COPYDIM; ++i) + CopyND::template ConstSrc::Accumulate( + src + i * SRC_STRIDE, dst + i * dst_stride, acc, dst_otherdims...); + } + }; + + template + struct ConstDst + { + template + static DACE_HDFI void Copy(const T *src, T *dst, const int& src_stride, const Args&... 
src_otherdims) + { +#ifndef __CUDA_ARCH__ + // Memcpy specialization + if (sizeof...(OTHER_COPYDIMS) == 0 && src_stride == 1 && DST_STRIDE == 1) { + memcpy(dst, src, COPYDIM * sizeof(T) * VECLEN); + return; + } +#endif + + __DACE_UNROLL + for (int i = 0; i < COPYDIM; ++i) + CopyND::template ConstDst::Copy( + src + i * src_stride, dst + i * DST_STRIDE, src_otherdims...); + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc, const int& src_stride, const Args&... src_otherdims) + { + __DACE_UNROLL + for (int i = 0; i < COPYDIM; ++i) + CopyND::template ConstDst::Accumulate( + src + i * src_stride, dst + i * DST_STRIDE, acc, src_otherdims...); + } + }; + }; + + // Specialization for actual copy / accumulation + template + struct CopyND + { + template + struct ConstSrc + { + static DACE_HDFI void Copy(const T *src, T *dst) + { + *(vec *)dst = *(vec *)src; + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc) + { + *(vec *)dst = acc(*(vec *)dst, *(vec *)src); + } + }; + + template + struct ConstDst + { + static DACE_HDFI void Copy(const T *src, T *dst) + { + *(vec *)dst = *(vec *)src; + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc) + { + *(vec *)dst = acc(*(vec *)dst, *(vec *)src); + } + }; + }; + + template + struct CopyNDDynamic + { + template + struct ConstSrc + { + template + static DACE_HDFI void Copy(const T *src, T *dst, const int& copydim, const int& dst_stride, const Args&... otherdims) + { +#ifndef __CUDA_ARCH__ + // Memcpy specialization + if (N == 1 && SRC_STRIDE == 1 && dst_stride == 1) { + memcpy(dst, src, copydim * sizeof(T) * VECLEN); + return; + } +#endif + + __DACE_UNROLL + for (int i = 0; i < copydim; ++i) + CopyNDDynamic::template ConstSrc::Copy( + src + i * SRC_STRIDE, dst + i * dst_stride, otherdims...); + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc, const int& copydim, const int& dst_stride, const Args&... otherdims) + { + __DACE_UNROLL + for (int i = 0; i < copydim; ++i) + CopyNDDynamic::template ConstSrc::Accumulate( + src + i * SRC_STRIDE, dst + i * dst_stride, acc, otherdims...); + } + }; + + template + struct ConstDst + { + template + static DACE_HDFI void Copy(const T *src, T *dst, const int& copydim, const int& src_stride, const Args&... otherdims) + { +#ifndef __CUDA_ARCH__ + // Memcpy specialization + if (N == 1 && src_stride == 1 && DST_STRIDE == 1) { + memcpy(dst, src, copydim * sizeof(T) * VECLEN); + return; + } +#endif + + __DACE_UNROLL + for (int i = 0; i < copydim; ++i) + CopyNDDynamic::template ConstDst::Copy( + src + i * src_stride, dst + i * DST_STRIDE, otherdims...); + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc, const int& copydim, const int& src_stride, const Args&... otherdims) + { + __DACE_UNROLL + for (int i = 0; i < copydim; ++i) + CopyNDDynamic::template ConstDst::Accumulate( + src + i * src_stride, dst + i * DST_STRIDE, acc, otherdims...); + } + }; + + struct Dynamic + { + template + static DACE_HDFI void Copy(const T *src, T *dst, const int& copydim, const int& src_stride, const int& dst_stride, const Args&... 
otherdims) + { + static_assert(sizeof...(otherdims) == (N - 1) * 3, "Dimensionality mismatch in dynamic copy"); + +#ifndef __CUDA_ARCH__ + // Memcpy specialization + if (N == 1 && src_stride == 1 && dst_stride == 1) { + memcpy(dst, src, copydim * sizeof(T) * VECLEN); + return; + } +#endif + + __DACE_UNROLL + for (int i = 0; i < copydim; ++i) + CopyNDDynamic::Dynamic::Copy( + src + i * src_stride, dst + i * dst_stride, otherdims...); + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc, const int& copydim, const int& src_stride, const int& dst_stride, const Args&... otherdims) + { + static_assert(sizeof...(otherdims) == (N - 1) * 3, "Dimensionality mismatch in dynamic copy"); + __DACE_UNROLL + for (int i = 0; i < copydim; ++i) + CopyNDDynamic::Dynamic::Accumulate( + src + i * src_stride, dst + i * dst_stride, acc, otherdims...); + } + }; + }; + + template + struct CopyNDDynamic + { + template + struct ConstSrc + { + static DACE_HDFI void Copy(const T *src, T *dst) + { + *(vec *)dst = *(vec *)src; + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc) + { + *(vec *)dst = acc(*(vec *)dst, *(vec *)src); + } + }; + + template + struct ConstDst + { + static DACE_HDFI void Copy(const T *src, T *dst) + { + *(vec *)dst = *(vec *)src; + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc) + { + *(vec *)dst = acc(*(vec *)dst, *(vec *)src); + } + }; + + struct Dynamic + { + static DACE_HDFI void Copy(const T *src, T *dst) + { + *(vec *)dst = *(vec *)src; + } + + template + static DACE_HDFI void Accumulate(const T *src, T *dst, ACCUMULATE acc) + { + *(vec *)dst = acc(*(vec *)dst, *(vec *)src); + } + }; + }; + +} // namespace dace + +#endif // __DACE_COPY_H diff --git a/dace/runtime/include/dace/cuda/copy.cuh b/dace/runtime/include/dace/cuda/copy.cuh new file mode 100644 index 0000000000..be989bb66e --- /dev/null +++ b/dace/runtime/include/dace/cuda/copy.cuh @@ -0,0 +1,819 @@ +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the names of the copyright holders nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. 
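+
+// Usage sketch (illustrative, not part of the original file): the helpers
+// below perform cooperative, block-wide copies between global and shared
+// memory. For a hypothetical 32x8x1 thread block loading an 8x32 row-major
+// tile (template arguments, elided here, carry the block dimensions, copy
+// extents, and destination strides declared below):
+//
+//   __shared__ float tile[8][32];
+//   // Entire thread block copies the tile from global memory, then
+//   // synchronizes (ASYNC = false); gmem_ptr and row_pitch are hypothetical
+//   dace::GlobalToShared2D<...>(gmem_ptr, /*src_ystride=*/row_pitch,
+//                               /*src_xstride=*/1, &tile[0][0]);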
+#ifndef __DACE_CUDACOPY_CUH +#define __DACE_CUDACOPY_CUH + +#include +#include "../types.h" +#include "../vector.h" +#include "../reduction.h" + +namespace dace +{ + // Adapted from "MAPS: GPU Optimization and Memory Abstraction Framework" + // https://github.com/maps-gpu/MAPS + + // Converts from an integral amount of bytes to a type. + template + struct BytesToType + { + typedef void type; + }; + + #ifdef __DACE_BYTES_TO_TYPE + #error Using disallowed macro name __DACE_BYTES_TO_TYPE + #endif + + #define __DACE_BYTES_TO_TYPE(bytes, t) \ + template<> \ + struct BytesToType \ + { \ + typedef t type; \ + } + + __DACE_BYTES_TO_TYPE(16, float4); + __DACE_BYTES_TO_TYPE(8, uint64_t); + __DACE_BYTES_TO_TYPE(4, uint32_t); + __DACE_BYTES_TO_TYPE(2, uint16_t); + __DACE_BYTES_TO_TYPE(1, uint8_t); + + # undef __DACE_BYTES_TO_TYPE + + template + struct LinearizeTID + { + static DACE_DFI unsigned int get() + { + return threadIdx.x + threadIdx.y * BLOCK_WIDTH + + threadIdx.z * BLOCK_WIDTH * BLOCK_HEIGHT; + } + }; + + template + struct LinearizeTID + { + static DACE_DFI unsigned int get() + { + return threadIdx.x + threadIdx.y * BLOCK_WIDTH; + } + }; + + template + struct LinearizeTID + { + static DACE_DFI unsigned int get() + { + return threadIdx.x; + } + }; + + template + static DACE_DFI unsigned int GetLinearTID() { + return LinearizeTID::get(); + } + + //////////////////////////////////////////////////////////////////////// + // Detect optimal bit read preference + + enum + { + #if __CUDA_ARCH__ >= 500 + PREFERRED_GREAD_SIZE = 128 / 8, // 128-bit + PREFERRED_SWRITE_SIZE = 128 / 8, // 128-bit + #elif __CUDA_ARCH__ >= 300 + PREFERRED_GREAD_SIZE = 128 / 8, // 128-bit + PREFERRED_SWRITE_SIZE = 64 / 8, // 64-bit + #elif __CUDA_ARCH__ >= 130 + PREFERRED_GREAD_SIZE = 64 / 8, // 64-bit + PREFERRED_SWRITE_SIZE = 32 / 8, // 32-bit + #else + PREFERRED_GREAD_SIZE = 32 / 8, // Default to 32-bit loads + PREFERRED_SWRITE_SIZE = 32 / 8, // 32-bit + #endif + }; + + #define DEBUG_PRINT(...) do {} while(0) + #define BLOCK_PRINT(...) do {} while(0) + + //#define DEBUG_PRINT(...) do { if(threadIdx.x + threadIdx.y == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0 && threadIdx.z == 1) printf(__VA_ARGS__); } while(0) + //#define BLOCK_PRINT(...) 
do { if(blockIdx.x + blockIdx.y + blockIdx.z == 0) printf(__VA_ARGS__); } while(0) + + template + static DACE_DFI void GlobalToShared3D( + const T *ptr, int src_zstride, + int src_ystride, int src_xstride, T *smem) + { + // Linear thread ID + int ltid = GetLinearTID(); + + constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH; + constexpr int TOTAL_XYZ = COPY_XLEN * COPY_YLEN * COPY_ZLEN; + constexpr int TOTAL_XY = COPY_XLEN * COPY_YLEN; + constexpr int XY_SLICES = BLOCK_SIZE / TOTAL_XY; + constexpr int XY_REM = BLOCK_SIZE % TOTAL_XY; + constexpr int X_SLICES = BLOCK_SIZE / COPY_XLEN; + constexpr int X_REM = BLOCK_SIZE % COPY_XLEN; + + ////////////////////////////////////////////////////////////////////// + // Block size larger than number of elements, one read + if ((BLOCK_SIZE / TOTAL_XYZ) > 0) + { + DEBUG_PRINT("Chose path XYZ\n"); + + // De-linearize + int ltidx = ltid % COPY_XLEN; + int ltidy = (ltid / COPY_XLEN) % COPY_YLEN; + int ltidz = (ltid / COPY_XLEN) / COPY_YLEN; + + if (ltid < TOTAL_XYZ) + { + smem[ltidx*DST_XSTRIDE + ltidy * DST_YSTRIDE + ltidz * DST_ZSTRIDE] = + *(ptr + ltidx * src_xstride + + src_ystride * ltidy + + src_zstride * ltidz); + } + } + + ////////////////////////////////////////////////////////////////////// + // More than one XY slice + else if ((BLOCK_SIZE / TOTAL_XYZ) == 0 && XY_SLICES > 0 && XY_REM > 0) + { + DEBUG_PRINT("Chose path XY.1\n"); + + // Currently, only use threads in slice + // TODO(later): If contiguous (DST_YSTRIDE == COPY_XLEN), use the rest + constexpr int SLICES_PER_ITER = (XY_SLICES == 0 ? 1 : // Compilers are annoying + COPY_ZLEN / XY_SLICES); + constexpr int REMAINDER = (XY_SLICES == 0 ? 1 : // Compilers are annoying + COPY_ZLEN % XY_SLICES); + constexpr int REMOFF = SLICES_PER_ITER * XY_SLICES; + + if (ltid < (BLOCK_SIZE - XY_REM)) + { + // De-linearize + int ltidx = ltid % COPY_XLEN; + int ltidy = (ltid / COPY_XLEN) % COPY_YLEN; + int ltidz = (ltid / COPY_XLEN) / COPY_YLEN; + + #pragma unroll + for (int i = 0; i < SLICES_PER_ITER; ++i) + { + smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (i*XY_SLICES + ltidz) * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * ltidy + + src_zstride * (i * XY_SLICES + ltidz)); + } + + if (ltidz < REMAINDER) + { + // Read remainder + smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (REMOFF + ltidz) * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * ltidy + + src_zstride * (REMOFF + ltidz)); + } + } + } + + ////////////////////////////////////////////////////////////////////// + // Exactly n*XY slices + else if ((BLOCK_SIZE / TOTAL_XYZ) == 0 && XY_SLICES > 0 && XY_REM == 0) + { + DEBUG_PRINT("Chose path XY.2\n"); + + constexpr int SLICES_PER_ITER = (XY_SLICES == 0 ? 1 : // Compilers are annoying + COPY_ZLEN / XY_SLICES); + constexpr int REMAINDER = (XY_SLICES == 0 ? 
1 : // Compilers are annoying + COPY_ZLEN % XY_SLICES); + constexpr int REMOFF = SLICES_PER_ITER * XY_SLICES; + + // De-linearize + int ltidx = ltid % COPY_XLEN; + int ltidy = (ltid / COPY_XLEN) % COPY_YLEN; + int ltidz = (ltid / COPY_XLEN) / COPY_YLEN; + + #pragma unroll + for (int i = 0; i < SLICES_PER_ITER; ++i) + { + smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (i*XY_SLICES + ltidz) * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * ltidy + + src_zstride * (i * XY_SLICES + ltidz)); + } + + if (ltidz < REMAINDER) + { + // Read remainder + smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (REMOFF + ltidz) * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * ltidy + + src_zstride * (REMOFF + ltidz)); + } + } + + ////////////////////////////////////////////////////////////////////// + // More than X row + else if (XY_SLICES == 0 && X_SLICES > 0 && X_REM > 0) + { + DEBUG_PRINT("Chose path X.1\n"); + + // Currently, only use threads in row + // TODO(later): If contiguous (DST_YSTRIDE == COPY_XLEN), use the rest + constexpr int ROWS_PER_XY_SLICE = (X_SLICES == 0 ? 1 : // Compilers are annoying + COPY_YLEN / X_SLICES); + constexpr int REMAINDER = (X_SLICES == 0 ? 1 : // Compilers are annoying + COPY_YLEN % X_SLICES); + constexpr int REMOFF = ROWS_PER_XY_SLICE * X_SLICES; + + if (ltid < (BLOCK_SIZE - X_REM)) + { + // De-linearize + int ltidx = ltid % COPY_XLEN; + int ltidy = ltid / COPY_XLEN; + + #pragma unroll + for (int i = 0; i < COPY_ZLEN; ++i) + { + #pragma unroll + for (int j = 0; j < ROWS_PER_XY_SLICE; ++j) + { + smem[ltidx * DST_XSTRIDE + (j*X_SLICES + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * (j * X_SLICES + ltidy) + + src_zstride * i); + } + + if (ltidy < REMAINDER) + { + // Read remainder + smem[ltidx * DST_XSTRIDE + (REMOFF + ltidy)* DST_YSTRIDE + i * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * (REMOFF + ltidy) + + src_zstride * i); + } + + } + } + } + + ////////////////////////////////////////////////////////////////////// + // Exactly n*X rows + else if (XY_SLICES == 0 && X_SLICES > 0 && X_REM == 0) + { + DEBUG_PRINT("Chose path X.2\n"); + + constexpr int ROWS_PER_XY_SLICE = (X_SLICES == 0 ? 1 : // Compilers are annoying + COPY_YLEN / X_SLICES); + constexpr int REMAINDER = (X_SLICES == 0 ? 
1 : // Compilers are annoying + COPY_YLEN % X_SLICES); + constexpr int REMOFF = ROWS_PER_XY_SLICE * X_SLICES; + + // De-linearize + int ltidx = ltid % COPY_XLEN; + int ltidy = ltid / COPY_XLEN; + + #pragma unroll + for (int i = 0; i < COPY_ZLEN; ++i) + { + #pragma unroll + for (int j = 0; j < ROWS_PER_XY_SLICE; ++j) + { + smem[ltidx * DST_XSTRIDE + (j*X_SLICES + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * (j * X_SLICES + ltidy) + + src_zstride * i); + } + + if (ltidy < REMAINDER) + { + // Read remainder + smem[ltidx * DST_XSTRIDE + (REMOFF + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] = + *(ptr + + src_xstride * ltidx + + src_ystride * (REMOFF + ltidy) + + src_zstride * i); + } + + } + } + + ////////////////////////////////////////////////////////////////////// + // Less than one X row + else if (X_SLICES == 0) + { + DEBUG_PRINT("Chose path X.3\n"); + + + constexpr int ITERATIONS_PER_ROW = COPY_XLEN / BLOCK_SIZE; + constexpr int REMAINDER = COPY_XLEN % BLOCK_SIZE; + constexpr int REMOFF = ITERATIONS_PER_ROW * BLOCK_SIZE; + + #pragma unroll + for (int i = 0; i < COPY_ZLEN; ++i) + { + #pragma unroll + for (int j = 0; j < COPY_YLEN; ++j) + { + #pragma unroll + for (int k = 0; k < ITERATIONS_PER_ROW; ++k) + { + smem[(k * BLOCK_SIZE + ltid) * DST_XSTRIDE + j * DST_YSTRIDE + i * DST_ZSTRIDE] = + *(ptr + + src_xstride * (k * BLOCK_SIZE + ltid) + + src_ystride * j + + src_zstride * i); + } + + if (ltid < REMAINDER) + { + // Read remainder + smem[(REMOFF + ltid) * DST_ZSTRIDE + j * DST_YSTRIDE + i * DST_ZSTRIDE] = + *(ptr + + src_xstride * (REMOFF + ltid) + + src_ystride * j + + src_zstride * i); + } + } + } + } + + ////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////// + + + if (!ASYNC) + __syncthreads(); + } + + template + static DACE_DFI void GlobalToShared1D( + const T *ptr, int src_xstride, T *smem) + { + GlobalToShared3D( + ptr, 1, 1, src_xstride, smem); + } + + template + static DACE_DFI void GlobalToShared2D( + const T *ptr, int src_ystride, int src_xstride, + T *smem) + { + GlobalToShared3D( + ptr, 1, src_ystride, src_xstride, smem); + } + + template + static DACE_DFI void GlobalToShared3DDynamic( + const T *ptr, int src_zstride, + int src_ystride, int src_xstride, T *smem, + int COPY_ZLEN, int COPY_YLEN, int COPY_XLEN) + { + // Linear thread ID + int ltid = GetLinearTID(); + + int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH; + int TOTAL_XYZ = COPY_XLEN * COPY_YLEN * COPY_ZLEN; + int TOTAL_XY = COPY_XLEN * COPY_YLEN; + int XY_SLICES = BLOCK_SIZE / TOTAL_XY; + int XY_REM = BLOCK_SIZE % TOTAL_XY; + int X_SLICES = BLOCK_SIZE / COPY_XLEN; + int X_REM = BLOCK_SIZE % COPY_XLEN; + + ////////////////////////////////////////////////////////////////////// + // Block size larger than number of elements, one read + if ((BLOCK_SIZE / TOTAL_XYZ) > 0) + { + DEBUG_PRINT("Chose path XYZ\n"); + + // De-linearize + int ltidx = ltid % COPY_XLEN; + int ltidy = (ltid / COPY_XLEN) % COPY_YLEN; + int ltidz = (ltid / COPY_XLEN) / COPY_YLEN; + + if (ltid < TOTAL_XYZ) + { + smem[ltidx*DST_XSTRIDE + ltidy * DST_YSTRIDE + ltidz * DST_ZSTRIDE] = + *(ptr + ltidx * src_xstride + + src_ystride * ltidy + + src_zstride * ltidz); + } + } + + ////////////////////////////////////////////////////////////////////// + // More than one XY slice + else if ((BLOCK_SIZE / TOTAL_XYZ) == 0 && XY_SLICES > 0 && XY_REM > 0) + { + DEBUG_PRINT("Chose path XY.1\n"); + + // Currently, only use threads 
+    template
+    static DACE_DFI void GlobalToShared3DDynamic(
+        const T *ptr, int src_zstride,
+        int src_ystride, int src_xstride, T *smem,
+        int COPY_ZLEN, int COPY_YLEN, int COPY_XLEN)
+    {
+        // Linear thread ID
+        int ltid = GetLinearTID();
+
+        int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
+        int TOTAL_XYZ = COPY_XLEN * COPY_YLEN * COPY_ZLEN;
+        int TOTAL_XY = COPY_XLEN * COPY_YLEN;
+        int XY_SLICES = BLOCK_SIZE / TOTAL_XY;
+        int XY_REM = BLOCK_SIZE % TOTAL_XY;
+        int X_SLICES = BLOCK_SIZE / COPY_XLEN;
+        int X_REM = BLOCK_SIZE % COPY_XLEN;
+
+        //////////////////////////////////////////////////////////////////////
+        // Block size larger than number of elements, one read
+        if ((BLOCK_SIZE / TOTAL_XYZ) > 0)
+        {
+            DEBUG_PRINT("Chose path XYZ\n");
+
+            // De-linearize
+            int ltidx = ltid % COPY_XLEN;
+            int ltidy = (ltid / COPY_XLEN) % COPY_YLEN;
+            int ltidz = (ltid / COPY_XLEN) / COPY_YLEN;
+
+            if (ltid < TOTAL_XYZ)
+            {
+                smem[ltidx*DST_XSTRIDE + ltidy * DST_YSTRIDE + ltidz * DST_ZSTRIDE] =
+                    *(ptr + ltidx * src_xstride +
+                      src_ystride * ltidy +
+                      src_zstride * ltidz);
+            }
+        }
+
+        //////////////////////////////////////////////////////////////////////
+        // More than one XY slice
+        else if ((BLOCK_SIZE / TOTAL_XYZ) == 0 && XY_SLICES > 0 && XY_REM > 0)
+        {
+            DEBUG_PRINT("Chose path XY.1\n");
+
+            // Currently, only use threads in slice
+            // TODO(later): If contiguous (DST_YSTRIDE == COPY_XLEN), use the rest
+            int SLICES_PER_ITER = (XY_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_ZLEN / XY_SLICES);
+            int REMAINDER = (XY_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_ZLEN % XY_SLICES);
+            int REMOFF = SLICES_PER_ITER * XY_SLICES;
+
+            if (ltid < (BLOCK_SIZE - XY_REM))
+            {
+                // De-linearize
+                int ltidx = ltid % COPY_XLEN;
+                int ltidy = (ltid / COPY_XLEN) % COPY_YLEN;
+                int ltidz = (ltid / COPY_XLEN) / COPY_YLEN;
+
+                #pragma unroll
+                for (int i = 0; i < SLICES_PER_ITER; ++i)
+                {
+                    smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (i*XY_SLICES + ltidz) * DST_ZSTRIDE] =
+                        *(ptr +
+                          src_xstride * ltidx +
+                          src_ystride * ltidy +
+                          src_zstride * (i * XY_SLICES + ltidz));
+                }
+
+                if (ltidz < REMAINDER)
+                {
+                    // Read remainder
+                    smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (REMOFF + ltidz) * DST_ZSTRIDE] =
+                        *(ptr +
+                          src_xstride * ltidx +
+                          src_ystride * ltidy +
+                          src_zstride * (REMOFF + ltidz));
+                }
+            }
+        }
+
+        //////////////////////////////////////////////////////////////////////
+        // Exactly n*XY slices
+        else if ((BLOCK_SIZE / TOTAL_XYZ) == 0 && XY_SLICES > 0 && XY_REM == 0)
+        {
+            DEBUG_PRINT("Chose path XY.2\n");
+
+            int SLICES_PER_ITER = (XY_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_ZLEN / XY_SLICES);
+            int REMAINDER = (XY_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_ZLEN % XY_SLICES);
+            int REMOFF = SLICES_PER_ITER * XY_SLICES;
+
+            // De-linearize
+            int ltidx = ltid % COPY_XLEN;
+            int ltidy = (ltid / COPY_XLEN) % COPY_YLEN;
+            int ltidz = (ltid / COPY_XLEN) / COPY_YLEN;
+
+            #pragma unroll
+            for (int i = 0; i < SLICES_PER_ITER; ++i)
+            {
+                smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (i*XY_SLICES + ltidz) * DST_ZSTRIDE] =
+                    *(ptr +
+                      src_xstride * ltidx +
+                      src_ystride * ltidy +
+                      src_zstride * (i * XY_SLICES + ltidz));
+            }
+
+            if (ltidz < REMAINDER)
+            {
+                // Read remainder
+                smem[ltidx * DST_XSTRIDE + ltidy * DST_YSTRIDE + (REMOFF + ltidz) * DST_ZSTRIDE] =
+                    *(ptr +
+                      src_xstride * ltidx +
+                      src_ystride * ltidy +
+                      src_zstride * (REMOFF + ltidz));
+            }
+        }
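+        // Summary of the path dispatch in this function: the copy specializes
+        // on how the thread-block size relates to the copied volume.
+        //   XYZ       - the block covers the whole volume; one read per thread
+        //   XY.1/XY.2 - the block covers one or more XY slices (with/without
+        //               leftover threads); loop over Z in XY_SLICES-sized steps
+        //   X.1/X.2   - the block covers one or more X rows; loop over Y and Z
+        //   X.3       - the block is smaller than one X row; several
+        //               iterations per row, plus a guarded remainder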
+        //////////////////////////////////////////////////////////////////////
+        // More than one X row
+        else if (XY_SLICES == 0 && X_SLICES > 0 && X_REM > 0)
+        {
+            DEBUG_PRINT("Chose path X.1\n");
+
+            // Currently, only use threads in row
+            // TODO(later): If contiguous (DST_YSTRIDE == COPY_XLEN), use the rest
+            int ROWS_PER_XY_SLICE = (X_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_YLEN / X_SLICES);
+            int REMAINDER = (X_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_YLEN % X_SLICES);
+            int REMOFF = ROWS_PER_XY_SLICE * X_SLICES;
+
+            if (ltid < (BLOCK_SIZE - X_REM))
+            {
+                // De-linearize
+                int ltidx = ltid % COPY_XLEN;
+                int ltidy = ltid / COPY_XLEN;
+
+                #pragma unroll
+                for (int i = 0; i < COPY_ZLEN; ++i)
+                {
+                    #pragma unroll
+                    for (int j = 0; j < ROWS_PER_XY_SLICE; ++j)
+                    {
+                        smem[ltidx * DST_XSTRIDE + (j*X_SLICES + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] =
+                            *(ptr +
+                              src_xstride * ltidx +
+                              src_ystride * (j * X_SLICES + ltidy) +
+                              src_zstride * i);
+                    }
+
+                    if (ltidy < REMAINDER)
+                    {
+                        // Read remainder
+                        smem[ltidx * DST_XSTRIDE + (REMOFF + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] =
+                            *(ptr +
+                              src_xstride * ltidx +
+                              src_ystride * (REMOFF + ltidy) +
+                              src_zstride * i);
+                    }
+
+                }
+            }
+        }
+
+        //////////////////////////////////////////////////////////////////////
+        // Exactly n*X rows
+        else if (XY_SLICES == 0 && X_SLICES > 0 && X_REM == 0)
+        {
+            DEBUG_PRINT("Chose path X.2\n");
+
+            int ROWS_PER_XY_SLICE = (X_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_YLEN / X_SLICES);
+            int REMAINDER = (X_SLICES == 0 ? 1 : // Compilers are annoying
+                COPY_YLEN % X_SLICES);
+            int REMOFF = ROWS_PER_XY_SLICE * X_SLICES;
+
+            // De-linearize
+            int ltidx = ltid % COPY_XLEN;
+            int ltidy = ltid / COPY_XLEN;
+
+            #pragma unroll
+            for (int i = 0; i < COPY_ZLEN; ++i)
+            {
+                #pragma unroll
+                for (int j = 0; j < ROWS_PER_XY_SLICE; ++j)
+                {
+                    smem[ltidx * DST_XSTRIDE + (j*X_SLICES + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] =
+                        *(ptr +
+                          src_xstride * ltidx +
+                          src_ystride * (j * X_SLICES + ltidy) +
+                          src_zstride * i);
+                }
+
+                if (ltidy < REMAINDER)
+                {
+                    // Read remainder
+                    smem[ltidx * DST_XSTRIDE + (REMOFF + ltidy) * DST_YSTRIDE + i * DST_ZSTRIDE] =
+                        *(ptr +
+                          src_xstride * ltidx +
+                          src_ystride * (REMOFF + ltidy) +
+                          src_zstride * i);
+                }
+
+            }
+        }
+
+        //////////////////////////////////////////////////////////////////////
+        // Less than one X row
+        else if (X_SLICES == 0)
+        {
+            DEBUG_PRINT("Chose path X.3\n");
+
+            int ITERATIONS_PER_ROW = COPY_XLEN / BLOCK_SIZE;
+            int REMAINDER = COPY_XLEN % BLOCK_SIZE;
+            int REMOFF = ITERATIONS_PER_ROW * BLOCK_SIZE;
+
+            #pragma unroll
+            for (int i = 0; i < COPY_ZLEN; ++i)
+            {
+                #pragma unroll
+                for (int j = 0; j < COPY_YLEN; ++j)
+                {
+                    #pragma unroll
+                    for (int k = 0; k < ITERATIONS_PER_ROW; ++k)
+                    {
+                        smem[(k * BLOCK_SIZE + ltid) * DST_XSTRIDE + j * DST_YSTRIDE + i * DST_ZSTRIDE] =
+                            *(ptr +
+                              src_xstride * (k * BLOCK_SIZE + ltid) +
+                              src_ystride * j +
+                              src_zstride * i);
+                    }
+
+                    if (ltid < REMAINDER)
+                    {
+                        // Read remainder
+                        smem[(REMOFF + ltid) * DST_XSTRIDE + j * DST_YSTRIDE + i * DST_ZSTRIDE] =
+                            *(ptr +
+                              src_xstride * (REMOFF + ltid) +
+                              src_ystride * j +
+                              src_zstride * i);
+                    }
+                }
+            }
+        }
+
+        //////////////////////////////////////////////////////////////////////
+        //////////////////////////////////////////////////////////////////////
+
+        if (!ASYNC)
+            __syncthreads();
+    }
+
+    template
+    static DACE_DFI void GlobalToShared1DDynamic(
+        const T *ptr, int src_xstride, T *smem, int COPY_XLEN)
+    {
+        GlobalToShared3DDynamic(
+            ptr, 1, 1, src_xstride, smem, 1, 1, COPY_XLEN);
+    }
+
+    template
+    static DACE_DFI void GlobalToShared2DDynamic(
+        const T *ptr, int src_ystride, int src_xstride,
+        T *smem, int COPY_YLEN, int COPY_XLEN)
+    {
+        GlobalToShared3DDynamic(
+            ptr, 1, src_ystride, src_xstride, smem, 1, COPY_YLEN, COPY_XLEN);
+    }
+
+    /*
+    template
+    static DACE_DFI void SharedToGlobal1D(
+        const T *smem, int src_xstride, T *ptr)
+    {
+        GlobalToShared3D(
+            smem, 1, 1, src_xstride, ptr);
+    }
+    */
+
+    template
+    struct ResetShared
+    {
+        static DACE_DFI void Reset(T *smem) {
+            // Linear thread ID
+            int ltid = GetLinearTID();
+            constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
+            constexpr int TOTAL = SMEM_TOTAL_ELEMENTS;
+            constexpr int WRITES = TOTAL / BLOCK_SIZE;
+            constexpr int REM_WRITES = TOTAL % BLOCK_SIZE;
+
+            #pragma unroll
+            for (int i = 0; i < WRITES; ++i) {
+                *(smem + (ltid + i * BLOCK_SIZE) * DST_XSTRIDE) = T(0);
+            }
+
+            if (REM_WRITES != 0) {
+                if (ltid < REM_WRITES)
+                    *(smem + (ltid + WRITES * BLOCK_SIZE) * DST_XSTRIDE) = T(0);
+            }
+
+            if (!ASYNC)
+                __syncthreads();
+        }
+    };
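ResetShared above zeroes a shared buffer cooperatively; SharedToGlobal1D below flushes results back to global memory through a write-conflict resolution (WCR) reduction. A sketch of the atomic flavor of that flush, with `atomicAdd` standing in for the generic `wcr_fixed` reduction (function and parameter names here are hypothetical):

```cpp
// Sketch of the block-cooperative accumulate-flush pattern: every thread
// writes block-strided elements; a guarded tail handles the remainder.
template <int BLOCK_SIZE, int LEN>
__device__ void flush_accumulate(const float *smem, float *global_out) {
    int ltid = threadIdx.x;  // assumes a 1D block for brevity
    constexpr int WRITES = LEN / BLOCK_SIZE;
    constexpr int REM_WRITES = LEN % BLOCK_SIZE;

    #pragma unroll
    for (int i = 0; i < WRITES; ++i)
        atomicAdd(&global_out[ltid + i * BLOCK_SIZE],
                  smem[ltid + i * BLOCK_SIZE]);

    if (ltid < REM_WRITES)  // tail: first REM_WRITES threads only
        atomicAdd(&global_out[ltid + WRITES * BLOCK_SIZE],
                  smem[ltid + WRITES * BLOCK_SIZE]);
}
```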
+    template
+    struct SharedToGlobal1D
+    {
+        template
+        static DACE_DFI void Accum(const T *smem, int src_xstride, T *ptr, WCR wcr)
+        {
+            if (!ASYNC)
+                __syncthreads();
+
+            // Linear thread ID
+            int ltid = GetLinearTID();
+            constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
+            constexpr int TOTAL = COPY_XLEN;
+            constexpr int WRITES = TOTAL / BLOCK_SIZE;
+            constexpr int REM_WRITES = TOTAL % BLOCK_SIZE;
+
+            #pragma unroll
+            for (int i = 0; i < WRITES; ++i) {
+                wcr_custom::template reduce(
+                    wcr, ptr + (ltid + i * BLOCK_SIZE) * DST_XSTRIDE,
+                    *(smem + (ltid + i * BLOCK_SIZE) * src_xstride));
+            }
+
+            if (REM_WRITES != 0) {
+                if (ltid < REM_WRITES)
+                    wcr_custom::template reduce(
+                        ptr + (ltid + WRITES * BLOCK_SIZE) * DST_XSTRIDE,
+                        *(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride));
+            }
+        }
+
+        template
+        static DACE_DFI void Accum(const T *smem, int src_xstride, T *ptr)
+        {
+            if (!ASYNC)
+                __syncthreads();
+
+            // Linear thread ID
+            int ltid = GetLinearTID();
+            constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
+            constexpr int TOTAL = COPY_XLEN;
+            constexpr int WRITES = TOTAL / BLOCK_SIZE;
+            constexpr int REM_WRITES = TOTAL % BLOCK_SIZE;
+
+            #pragma unroll
+            for (int i = 0; i < WRITES; ++i) {
+                wcr_fixed::template reduce_atomic(
+                    ptr + (ltid + i * BLOCK_SIZE) * DST_XSTRIDE,
+                    *(smem + (ltid + i * BLOCK_SIZE) * src_xstride));
+            }
+
+            if (REM_WRITES != 0) {
+                if (ltid < REM_WRITES)
+                    wcr_fixed::template reduce_atomic(
+                        ptr + (ltid + WRITES * BLOCK_SIZE) * DST_XSTRIDE,
+                        *(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride));
+            }
+        }
+    };
+
+    // TODO: Make like SharedToGlobal1D
+    template
+    static DACE_DFI void SharedToGlobal2D(
+        const T *ptr, int src_ystride, int src_xstride,
+        T *smem)
+    {
+        GlobalToShared3D(
+            ptr, 1, src_ystride, src_xstride, smem);
+    }
+
+    template
+    static DACE_DFI void SharedToGlobal2DDynamic(
+        const T *ptr, int src_ystride, int src_xstride,
+        T *smem, int COPY_YLEN, int COPY_XLEN)
+    {
+        GlobalToShared3DDynamic(
+            ptr, 1, src_ystride, src_xstride, smem, 1, COPY_YLEN, COPY_XLEN);
+    }
+
+} // namespace dace
+
+#endif // __DACE_CUDACOPY_CUH
diff --git a/dace/runtime/include/dace/cuda/cudacommon.cuh b/dace/runtime/include/dace/cuda/cudacommon.cuh
new file mode 100644
index 0000000000..61aa4623df
--- /dev/null
+++ b/dace/runtime/include/dace/cuda/cudacommon.cuh
@@ -0,0 +1,19 @@
+#ifndef __DACE_CUDACOMMON_CUH
+#define __DACE_CUDACOMMON_CUH
+
+#define DACE_CUDA_CHECK(err) do { \
+    cudaError_t errr = (err); \
+    if(errr != (cudaError_t)0) \
+    { \
+        printf("CUDA ERROR at %s:%d, code: %d\n", __FILE__, __LINE__, errr); \
+    } \
+} while(0)
+
+namespace dace {
+    namespace cuda {
+        extern cudaStream_t __streams[];
+        extern cudaEvent_t __events[];
+    } // namespace cuda
+} // namespace dace
+
+#endif // __DACE_CUDACOMMON_CUH
diff --git a/dace/runtime/include/dace/cuda/dynmap.cuh b/dace/runtime/include/dace/cuda/dynmap.cuh
new file mode 100644
index 0000000000..14330f1425
--- /dev/null
+++ b/dace/runtime/include/dace/cuda/dynmap.cuh
@@ -0,0 +1,249 @@
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the names of the copyright holders nor the names of its
+//   contributors may be used to endorse or promote products derived from this
+//   software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+
+// Adapted from "Groute: An Asynchronous Multi-GPU Programming Framework"
+// http://www.github.com/groute/groute
+
+#ifndef __DACE_DYNMAP_CUH
+#define __DACE_DYNMAP_CUH
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../../../../external/cub/cub/util_ptx.cuh"
+
+#define __FULL_MASK 0xffffffff
+
+namespace dace {
+    /**
+     * A map (usually dynamically sized) that can be rescheduled across a
+     * threadblock
+     **/
+    template
+    struct DynamicMap
+    {
+        template struct warp_np {
+            volatile index_type owner[WARPS_PER_TB];
+            volatile index_type start[WARPS_PER_TB];
+            volatile index_type size[WARPS_PER_TB];
+            volatile index_type src[WARPS_PER_TB];
+        };
+
+        struct tb_np {
+            index_type owner;
+            index_type start;
+            index_type size;
+            index_type src;
+        };
+
+        struct empty_np {
+        };
+
+        template
+        union np_shared {
+            // for scans
+            ts_type temp_storage;
+
+            // for tb-level np
+            TTB tb;
+
+            // for warp-level np
+            TWP warp;
+
+            // fine-grained schedule (unused)
+            //TFG fg;
+        };
+
+        /*
+         * @brief A structure representing a scheduled chunk of work
+         */
+        struct np_local
+        {
+            index_type size; // work size
+            index_type start; // work start
+            index_type src; // work source thread / metadata
+        };
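+        // The scheduler below load-balances irregular work in three tiers:
+        //   1. Thread-block level: ranges with at least blockDim.x items are
+        //      processed by the whole block, one elected owner at a time.
+        //   2. Warp level: ranges with at least 32 items are claimed and
+        //      broadcast within each warp (via ballot + shuffle).
+        //   3. Serial: each thread finishes its remaining (< 32) items alone.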
+        template
+        __device__ __forceinline__ static void schedule(index_type local_start, index_type local_end, index_type local_src, Functor&& work)
+        {
+            const int WP_SIZE = CUB_PTX_WARP_THREADS;
+            const int TB_SIZE = BLOCK_SIZE;
+
+            const int NP_WP_CROSSOVER = CUB_PTX_WARP_THREADS;
+            const int NP_TB_CROSSOVER = blockDim.x;
+
+            typedef union std::conditional,
+                np_shared>>::type np_shared_type;
+
+            __shared__ np_shared_type np_shared;
+
+            index_type local_size = local_end - local_start;
+
+            if (threadIdx.x == 0)
+            {
+                np_shared.tb.owner = TB_SIZE + 1;
+            }
+
+            __syncthreads();
+
+            //
+            // First scheduler: processing high-degree work items using the entire block
+            //
+            while (true)
+            {
+                if (local_size >= NP_TB_CROSSOVER)
+                {
+                    // 'Elect' one owner for the entire thread block
+                    np_shared.tb.owner = threadIdx.x;
+                }
+
+                __syncthreads();
+
+                if (np_shared.tb.owner == TB_SIZE + 1)
+                {
+                    // No owner was elected, i.e. no high-degree work items remain
+
+                    // No need to sync threads before moving on to WP scheduler
+                    // because it does not use shared memory
+                    if (!WARP_INTRINSICS)
+                        __syncthreads(); // Necessary due to the shared memory union used by both TB and WP schedulers
+                    break;
+                }
+
+                if (np_shared.tb.owner == threadIdx.x)
+                {
+                    // This thread is the owner
+                    np_shared.tb.start = local_start;
+                    np_shared.tb.size = local_size;
+                    np_shared.tb.src = local_src;
+
+                    // Mark this work-item as processed for future schedulers
+                    local_start = 0;
+                    local_size = 0;
+                }
+
+                __syncthreads();
+
+                index_type start = np_shared.tb.start;
+                index_type size = np_shared.tb.size;
+                index_type src = np_shared.tb.src;
+
+                if (np_shared.tb.owner == threadIdx.x)
+                {
+                    np_shared.tb.owner = TB_SIZE + 1;
+                }
+
+                // Use all threads in thread block to execute individual work
+                for (int ii = threadIdx.x; ii < size; ii += TB_SIZE)
+                {
+                    work(start + ii, src);
+                }
+
+                __syncthreads();
+            }
+
+            //
+            // Second scheduler: tackle medium-degree work items using the warp
+            //
+            const int warp_id = cub::WarpId();
+            const int lane_id = cub::LaneId();
+
+            while (__any_sync(__FULL_MASK, local_size >= NP_WP_CROSSOVER))
+            {
+                index_type start, size, src;
+                if (WARP_INTRINSICS)
+                {
+                    // Compete for work scheduling
+                    unsigned int mask = __ballot_sync(__FULL_MASK, local_size >= NP_WP_CROSSOVER ? 1 : 0);
+                    // Select a deterministic winner
+                    int leader = __ffs(mask) - 1;
+
+                    // Broadcast data from the leader
+                    start = cub::ShuffleIndex(local_start, leader, mask);
+                    size = cub::ShuffleIndex(local_size, leader, mask);
+                    src = cub::ShuffleIndex(local_src, leader, mask);
+
+                    if (leader == lane_id)
+                    {
+                        // Mark this work-item as processed
+                        local_start = 0;
+                        local_size = 0;
+                    }
+                }
+                else
+                {
+                    // In order for this to compile, it should be refactored to another function
+                    /*
+                    if (local_size >= NP_WP_CROSSOVER)
+                    {
+                        // Again, race to select an owner for warp
+                        np_shared.warp.owner[warp_id] = lane_id;
+                    }
+                    if (np_shared.warp.owner[warp_id] == lane_id)
+                    {
+                        // This thread is owner
+                        np_shared.warp.start[warp_id] = local_start;
+                        np_shared.warp.size[warp_id] = local_size;
+
+                        // Mark this work-item as processed
+                        local_start = 0;
+                        local_size = 0;
+                    }
+                    start = np_shared.warp.start[warp_id];
+                    size = np_shared.warp.size[warp_id];
+                    */
+                }
+
+                for (int ii = lane_id; ii < size; ii += WP_SIZE)
+                {
+                    work(start + ii, src);
+                }
+            }
+
+            __syncthreads();
+
+            //
+            // Third scheduler: tackle all work-items with size < 32 serially
+            //
+            // It is possible to disable this scheduler by setting NP_WP_CROSSOVER to 0
+
+            for (int ii = 0; ii < local_size; ii++)
+            {
+                work(local_start + ii, local_src);
+            }
+        }
+    };
+
+} // namespace dace
+
+#endif // __DACE_DYNMAP_CUH
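How a consumer drives `DynamicMap::schedule` is easiest to see in a sketch. The example below distributes CSR-style rows of irregular length across a thread block; the kernel, its arguments, and the `<index_type, BLOCK_SIZE, WARP_INTRINSICS>` template arguments are assumptions for illustration (the parameter list is elided in this diff), and the device lambda requires nvcc's extended-lambda support:

```cpp
// Hypothetical consumer: each thread owns one row of a CSR matrix and
// contributes its (irregularly sized) range; schedule() rebalances the
// per-row work across the thread block.
__global__ void row_sums(const int *row_offsets, const float *values,
                         float *out, int num_rows) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int begin = 0, end = 0;
    if (row < num_rows) {
        begin = row_offsets[row];
        end = row_offsets[row + 1];
    }
    // Every thread must participate, even with an empty range, since the
    // scheduler synchronizes across the block.
    dace::DynamicMap<int, 256, true>::schedule(
        begin, end, row,
        [=] __device__(int idx, int src_row) {
            atomicAdd(&out[src_row], values[idx]);
        });
}
```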
diff --git a/dace/runtime/include/dace/cuda/stream.cuh b/dace/runtime/include/dace/cuda/stream.cuh
new file mode 100644
index 0000000000..84314ecf9a
--- /dev/null
+++ b/dace/runtime/include/dace/cuda/stream.cuh
@@ -0,0 +1,245 @@
+#ifndef __DACE_STREAM_CUH
+#define __DACE_STREAM_CUH
+
+#include
+#include
+#include
+#include
+#include
+#include // Used for the in-memory ctor call in the move assignment operator below
+
+#include
+#include
+
+#include "../../../../external/cub/cub/util_ptx.cuh"
+#include "../../../../external/cub/cub/warp/warp_reduce.cuh"
+#include "../../../../external/cub/cub/warp/warp_scan.cuh"
+
+#include "cudacommon.cuh"
+
+namespace dace {
+    // Adapted from https://devblogs.nvidia.com/cuda-pro-tip-optimized-filtering-warp-aggregated-atomics/
+    __inline__ __device__ uint32_t atomicAggInc(uint32_t *ctr) {
+        auto g = cooperative_groups::coalesced_threads();
+        uint32_t warp_res;
+        int rank = g.thread_rank();
+        if (rank == 0)
+            warp_res = atomicAdd(ctr, g.size());
+        return g.shfl(warp_res, 0) + rank;
+    }
+
+    __inline__ __device__ uint32_t atomicAggDec(uint32_t *ctr) {
+        auto g = cooperative_groups::coalesced_threads();
+        uint32_t warp_res;
+        int rank = g.thread_rank();
+        if (rank == 0)
+            warp_res = atomicAdd(ctr, -g.size());
+        return g.shfl(warp_res, 0) + rank;
+    }
+
+    /*
+    __inline__ __device__ uint32_t warpReduceSum(uint32_t val) {
+        for (int offset = CUB_PTX_WARP_THREADS / 2; offset > 0; offset /= 2)
+            val += __shfl_down(val, offset);
+        return val;
+    }
+    */
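+    // How the aggregated atomics above work: threads that arrive together
+    // form a coalesced group; only the group leader (rank 0) performs a
+    // single atomicAdd for the whole group, and the pre-increment value is
+    // then broadcast with shfl so each lane derives its own slot as
+    // warp_res + rank. This collapses up to 32 atomics into one.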
+    //
+    // Queue classes (device):
+    //
+
+    /*
+     * @brief A device-level MPMC Queue
+     */
+    template
+    class GPUStream
+    {
+    public:
+        T* m_data;
+        uint32_t *m_start, *m_end, *m_pending;
+        uint32_t m_capacity_mask;
+
+        __host__ GPUStream() : m_data(nullptr), m_start(nullptr), m_end(nullptr),
+                               m_pending(nullptr), m_capacity_mask(0) {}
+        __host__ __device__ GPUStream(T* data, uint32_t capacity,
+                                      uint32_t *start, uint32_t *end,
+                                      uint32_t *pending) :
+            m_data(data), m_start(start), m_end(end), m_pending(pending),
+            m_capacity_mask(IS_POWEROFTWO ? (capacity - 1) : capacity)
+        {
+            if (IS_POWEROFTWO) {
+                assert(((capacity - 1) & capacity) == 0); // Must be a power of two for handling circular overflow correctly
+            }
+        }
+
+        __device__ __forceinline__ void reset() const
+        {
+            *m_start = 0;
+            *m_end = 0;
+            *m_pending = 0;
+        }
+
+        __device__ __forceinline__ T pop()
+        {
+            uint32_t allocation = atomicAggInc(m_start);
+            return m_data[get_addr(allocation)];
+        }
+
+        __device__ __forceinline__ T *leader_pop(uint32_t count) {
+            uint32_t current = *m_start;
+            T *result = m_data + get_addr(current);
+            *m_start += count;
+            return result;
+        }
+
+        __device__ __forceinline__ uint32_t get_addr(const uint32_t& i) const {
+            if (IS_POWEROFTWO)
+                return i & m_capacity_mask;
+            else
+                return i % m_capacity_mask;
+        }
+
+        __device__ __forceinline__ void push(const T& item)
+        {
+            uint32_t allocation = atomicAggInc(m_pending);
+            m_data[get_addr(allocation)] = item;
+        }
+
+        /*
+        __device__ __forceinline__ void push(T *items, int count)
+        {
+            // Perform a warp-wide scan to get thread offsets
+            typedef cub::WarpScan WarpScan;
+            __shared__ typename WarpScan::TempStorage temp_storage[4];
+            int offset;
+            int warp_id = threadIdx.x / 32;
+            WarpScan(temp_storage[warp_id]).ExclusiveSum(count, offset);
+
+            // Atomic-add the total count once per warp
+            uint32_t addr;
+            if (threadIdx.x & 31 == 31) // Last thread
+                addr = atomicAdd(m_pending, offset + count);
+            // Broadcast starting address
+            addr = cub::ShuffleIndex(addr, 31, 0xffffffff);
+
+            // Copy data from each thread
+            for(int i = 0; i < count; ++i)
+                m_data[get_addr(addr + offset + i)] = items[i];
+        }
+        */
+
+        __device__ __forceinline__ void prepend(const T& item)
+        {
+            uint32_t allocation = atomicAggDec(m_start) - 1;
+            m_data[get_addr(allocation)] = item;
+        }
+
+        __device__ __forceinline__ T read(uint32_t i) const
+        {
+            return m_data[get_addr(*m_start + i)];
+        }
+
+        __device__ __forceinline__ uint32_t count() const
+        {
+            return *m_end - *m_start;
+        }
+
+        // Returns the 'count' of pending items and commits
+        __device__ __forceinline__ uint32_t commit_pending() const
+        {
+            uint32_t count = *m_pending - *m_end;
+
+            // Sync end with pending, this makes the pushed items visible to the consumer
+            *m_end = *m_pending;
+            return count;
+        }
+
+        __device__ __forceinline__ uint32_t get_start() const
+        {
+            return *m_start;
+        }
+
+        __device__ __forceinline__ uint32_t get_start_delta(uint32_t prev_start) const
+        {
+            return prev_start - *m_start;
+        }
+    };
+
+    ////////////////////////////////////////////////////////////
+    // Host controllers for GPU streams
+
+    template
+    __global__ void ResetGPUStream_kernel(GPUStream stream)
+    {
+        stream.reset();
+    }
+
+    template
+    void ResetGPUStream(GPUStream& stream)
+    {
+        void *args_reset[1] = { &stream };
+        DACE_CUDA_CHECK(cudaLaunchKernel((void *)&ResetGPUStream_kernel,
+                                         dim3(1, 1, 1), dim3(1, 1, 1),
+                                         args_reset, 0, (cudaStream_t)0));
+    }
+
+    template
+    __global__ void PushToGPUStream_kernel(GPUStream stream, T item)
+    {
+        stream.push(item);
+        stream.commit_pending();
+    }
+
+    template
+    void PushToGPUStream(GPUStream& stream, const T& item)
+    {
+        void *args_push[2] = { &stream, &item };
+        DACE_CUDA_CHECK(cudaLaunchKernel((void *)&PushToGPUStream_kernel,
+                                         dim3(1, 1, 1), dim3(1, 1, 1),
+                                         args_push, 0, (cudaStream_t)0));
+    }
+
+    ////////////////////////////////////////////////////////////
+    // Host memory management for GPU streams
+
+    template
+    GPUStream AllocGPUArrayStreamView(T *ptr, uint32_t capacity)
+    {
+        uint32_t *gStart, *gEnd, *gPending;
+        DACE_CUDA_CHECK(cudaMalloc(&gStart, sizeof(uint32_t)));
+        DACE_CUDA_CHECK(cudaMalloc(&gEnd, sizeof(uint32_t)));
+        DACE_CUDA_CHECK(cudaMalloc(&gPending, sizeof(uint32_t)));
+        DACE_CUDA_CHECK(cudaMemsetAsync(gStart, 0, sizeof(uint32_t)));
+        DACE_CUDA_CHECK(cudaMemsetAsync(gEnd, 0, sizeof(uint32_t)));
+        DACE_CUDA_CHECK(cudaMemsetAsync(gPending, 0, sizeof(uint32_t)));
+        return GPUStream(ptr, capacity, gStart, gEnd, gPending);
+    }
+
+    template
+    GPUStream AllocGPUStream(uint32_t capacity)
+    {
+        T *gData;
+        DACE_CUDA_CHECK(cudaMalloc(&gData, capacity * sizeof(T)));
+        return AllocGPUArrayStreamView(gData, capacity);
+    }
+
+    template
+    void FreeGPUArrayStreamView(GPUStream& stream)
+    {
+        DACE_CUDA_CHECK(cudaFree(stream.m_start));
+        DACE_CUDA_CHECK(cudaFree(stream.m_end));
+        DACE_CUDA_CHECK(cudaFree(stream.m_pending));
+    }
+
+    template
+    void FreeGPUStream(GPUStream& stream)
+    {
+        FreeGPUArrayStreamView(stream);
+        DACE_CUDA_CHECK(cudaFree(stream.m_data));
+    }
+
+} // namespace dace
+#endif // __DACE_STREAM_CUH
\ No newline at end of file
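Putting the host-side helpers above together, a typical lifecycle looks like the following sketch. Usage is illustrative; the template arguments follow the `<T, IS_POWEROFTWO>` convention suggested by the class, whose parameter list is elided in this diff, and the include path is hypothetical:

```cpp
#include "dace/cuda/stream.cuh"

void gpu_stream_lifecycle() {
    // Allocate a device queue of 1024 ints; a power-of-two capacity lets
    // get_addr() use the cheap mask-based index wraparound.
    auto q = dace::AllocGPUStream<int, true>(1024);

    dace::PushToGPUStream(q, 42);  // single-item push from the host
    // ... launch kernels that pop()/push() into q ...
    dace::ResetGPUStream(q);       // zero the start/end/pending counters

    dace::FreeGPUStream(q);        // frees both data and counter buffers
}
```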
diff --git a/dace/runtime/include/dace/cuda/vectype.cuh b/dace/runtime/include/dace/cuda/vectype.cuh
new file mode 100644
index 0000000000..f8acaa82e1
--- /dev/null
+++ b/dace/runtime/include/dace/cuda/vectype.cuh
@@ -0,0 +1,326 @@
+////////////////////////////////////////////////////////////////////////
+// Define some operators on vector types
+
+#define DEFINE_EXTTYPE1(T, NAME) \
+    struct exttype_##T##_##1 : NAME##1 { \
+        DACE_HDFI exttype_##T##_##1 operator*(const exttype_##T##_##1 &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = other.x * x; \
+            return result; \
+        } \
+        DACE_HDFI exttype_##T##_##1 operator+(const exttype_##T##_##1 &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = other.x + x; \
+            return result; \
+        } \
+        DACE_HDFI exttype_##T##_##1 operator-(const exttype_##T##_##1 &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = x - other.x; \
+            return result; \
+        } \
+        DACE_HDFI exttype_##T##_##1 operator/(const exttype_##T##_##1 &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = x / other.x; \
+            return result; \
+        } \
+        template \
+        DACE_HDFI exttype_##T##_##1 operator*(const U &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = other * x; \
+            return result; \
+        } \
+        template \
+        DACE_HDFI exttype_##T##_##1 operator+(const U &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = other + x; \
+            return result; \
+        } \
+        template \
+        DACE_HDFI exttype_##T##_##1 operator-(const U &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = x - other; \
+            return result; \
+        } \
+        template \
+        DACE_HDFI exttype_##T##_##1 operator/(const U &other) const { \
+            exttype_##T##_##1 result; \
+            result.x = x / other; \
+            return result; \
+        } \
+        template \
+        DACE_HDFI T operator[](const U &index) const { \
+            return x; \
+        } \
+    };
+#define DEFINE_EXTTYPE2(T, NAME) \
+    struct exttype_##T##_##2 : NAME##2 { \
+        DACE_HDFI exttype_##T##_##2 operator*(const exttype_##T##_##2 &other) const { \
+            exttype_##T##_##2 result; \
+            result.x = other.x * x; \
+            result.y = other.y * y; \
+            return result; \
+        } \
+        DACE_HDFI exttype_##T##_##2 operator+(const exttype_##T##_##2 &other) const { \
+            exttype_##T##_##2 result; \
+            result.x = other.x + x; \
+            result.y = other.y + y; \
+            return result; \
+        } \
+        DACE_HDFI exttype_##T##_##2 operator-(const exttype_##T##_##2 &other) const { \
+            exttype_##T##_##2 result; \
+            result.x = x - other.x; \
+            result.y = y - other.y; \
+            return result; \
+        } \
+        DACE_HDFI exttype_##T##_##2 operator/(const exttype_##T##_##2 &other) const { \
+            exttype_##T##_##2 result; \
+            result.x = x / other.x; \
+            result.y = y / other.y; \
+            return result; \
+        } \
+        template