Halide  14.0.0
Halide compiler and libraries
JITModule.h
Go to the documentation of this file.
1 #ifndef HALIDE_JIT_MODULE_H
2 #define HALIDE_JIT_MODULE_H
3 
4 /** \file
5  * Defines the struct representing lifetime and dependencies of
6  * a JIT compiled halide pipeline
7  */
8 
9 #include <map>
10 #include <memory>
11 
12 #include "IntrusivePtr.h"
13 #include "Type.h"
14 #include "runtime/HalideRuntime.h"
15 
16 namespace llvm {
17 class Module;
18 }
19 
20 namespace Halide {
21 
22 struct ExternCFunction;
23 struct JITExtern;
24 struct Target;
25 class Module;
26 
27 struct JITUserContext;
28 
29 /** A set of custom overrides of runtime functions. These only apply
30  * when JIT-compiling code. If you are doing AOT compilation, see
31  * HalideRuntime.h for instructions on how to replace runtime
32  * functions. */
33 struct JITHandlers {
34  /** Set the function called to print messages from the runtime. */
35  void (*custom_print)(JITUserContext *, const char *){nullptr};
36 
37  /** A custom malloc and free for halide to use. Malloc should
38  * return 32-byte aligned chunks of memory, and it should be safe
39  * for Halide to read slightly out of bounds (up to 8 bytes before
40  * the start or beyond the end). */
41  // @{
42  void *(*custom_malloc)(JITUserContext *, size_t){nullptr};
43  void (*custom_free)(JITUserContext *, void *){nullptr};
44  // @}
45 
46  /** A custom task handler to be called by the parallel for
47  * loop. It is useful to set this if you want to do some
48  * additional bookkeeping at the granularity of parallel
49  * tasks. The default implementation does this:
50  \code
51  extern "C" int halide_do_task(JITUserContext *user_context,
52  int (*f)(void *, int, uint8_t *),
53  int idx, uint8_t *state) {
54  return f(user_context, idx, state);
55  }
56  \endcode
57  *
58  * If you're trying to use a custom parallel runtime, you probably
59  * don't want to call this. See instead custom_do_par_for.
60  */
61  int (*custom_do_task)(JITUserContext *, int (*)(JITUserContext *, int, uint8_t *), int, uint8_t *){nullptr};
62 
63  /** A custom parallel for loop launcher. Useful if your app
64  * already manages a thread pool. The default implementation is
65  * equivalent to this:
66  \code
67  extern "C" int halide_do_par_for(JITUserContext *user_context,
68  int (*f)(void *, int, uint8_t *),
69  int min, int extent, uint8_t *state) {
70  int exit_status = 0;
71  parallel for (int idx = min; idx < min+extent; idx++) {
72  int job_status = halide_do_task(user_context, f, idx, state);
73  if (job_status) exit_status = job_status;
74  }
75  return exit_status;
76  }
77  \endcode
78  *
79  * However, notwithstanding the above example code, if one task
80  * fails, we may skip over other tasks, and if two tasks return
81  * different error codes, we may select one arbitrarily to return.
82  */
83  int (*custom_do_par_for)(JITUserContext *, int (*)(JITUserContext *, int, uint8_t *), int, int, uint8_t *){nullptr};
84 
85  /** The error handler function that be called in the case of
86  * runtime errors during halide pipelines. */
87  void (*custom_error)(JITUserContext *, const char *){nullptr};
88 
89  /** A custom routine to call when tracing is enabled. Call this
90  * on the output Func of your pipeline. This then sets custom
91  * routines for the entire pipeline, not just calls to this
92  * Func. */
94 
95  /** A method to use for Halide to resolve symbol names dynamically
96  * in the calling process or library from within the Halide
97  * runtime. Equivalent to dlsym with a null first argument. */
98  void *(*custom_get_symbol)(const char *name){nullptr};
99 
100  /** A method to use for Halide to dynamically load libraries from
101  * within the runtime. Equivalent to dlopen. Returns a handle to
102  * the opened library. */
103  void *(*custom_load_library)(const char *name){nullptr};
104 
105  /** A method to use for Halide to dynamically find a symbol within
106  * an opened library. Equivalent to dlsym. Takes a handle
107  * returned by custom_load_library as the first argument. */
108  void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
109 
110  /** A custom method for the Halide runtime acquire a cuda
111  * context. The cuda context is treated as a void * to avoid a
112  * dependence on the cuda headers. If the create argument is set
113  * to true, a context should be created if one does not already
114  * exist. */
115  int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr};
116 
117  /** The Halide runtime calls this when it is done with a cuda
118  * context. The default implementation does nothing. */
120 
121  /** A custom method for the Halide runtime to acquire a cuda
122  * stream to use. The cuda context and stream are both modelled
123  * as a void *, to avoid a dependence on the cuda headers. */
124  int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr};
125 };
126 
127 namespace Internal {
128 struct JITErrorBuffer;
129 }
130 
131 /** A context to be passed to Pipeline::realize. Inherit from this to
132  * pass your own custom context object. Modify the handlers field to
133  * override runtime functions per-call to realize. */
135  Internal::JITErrorBuffer *error_buffer{nullptr};
137 };
138 
139 namespace Internal {
140 
141 class JITModuleContents;
142 struct LoweredFunc;
143 
144 struct JITModule {
146 
147  struct Symbol {
148  void *address = nullptr;
149  Symbol() = default;
150  explicit Symbol(void *address)
151  : address(address) {
152  }
153  };
154 
156  JITModule(const Module &m, const LoweredFunc &fn,
157  const std::vector<JITModule> &dependencies = std::vector<JITModule>());
158 
159  /** Take a list of JITExterns and generate trampoline functions
160  * which can be called dynamically via a function pointer that
161  * takes an array of void *'s for each argument and the return
162  * value.
163  */
165  const std::map<std::string, JITExtern> &externs,
166  const std::string &suffix,
167  const std::vector<JITModule> &deps);
168 
169  /** The exports map of a JITModule contains all symbols which are
170  * available to other JITModules which depend on this one. For
171  * runtime modules, this is all of the symbols exported from the
172  * runtime. For a JITted Func, it generally only contains the main
173  * result Func of the compilation, which takes its name directly
174  * from the Func declaration. One can also make a module which
175  * contains no code itself but is just an exports maps providing
176  * arbitrary pointers to functions or global variables to JITted
177  * code. */
178  const std::map<std::string, Symbol> &exports() const;
179 
180  /** A pointer to the raw halide function. Its true type depends
181  * on the Argument vector passed to CodeGen_LLVM::compile. Image
182  * parameters become (halide_buffer_t *), and scalar parameters become
183  * pointers to the appropriate values. The final argument is a
184  * pointer to the halide_buffer_t defining the output. This will be nullptr for
185  * a JITModule which has not yet been compiled or one that is not
186  * a Halide Func compilation at all. */
187  void *main_function() const;
188 
189  /** Returns the Symbol structure for the routine documented in
190  * main_function. Returning a Symbol allows access to the LLVM
191  * type as well as the address. The address and type will be nullptr
192  * if the module has not been compiled. */
194 
195  /** Returns the Symbol structure for the argv wrapper routine
196  * corresponding to the entrypoint. The argv wrapper is callable
197  * via an array of void * pointers to the arguments for the
198  * call. Returning a Symbol allows access to the LLVM type as well
199  * as the address. The address and type will be nullptr if the module
200  * has not been compiled. */
202 
203  /** A slightly more type-safe wrapper around the raw halide
204  * module. Takes it arguments as an array of pointers that
205  * correspond to the arguments to \ref main_function . This will
206  * be nullptr for a JITModule which has not yet been compiled or one
207  * that is not a Halide Func compilation at all. */
208  // @{
209  typedef int (*argv_wrapper)(const void **args);
211  // @}
212 
213  /** Add another JITModule to the dependency chain. Dependencies
214  * are searched to resolve symbols not found in the current
215  * compilation unit while JITting. */
217  /** Registers a single Symbol as available to modules which depend
218  * on this one. The Symbol structure provides both the address and
219  * the LLVM type for the function, which allows type safe linkage of
220  * extenal routines. */
221  void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol);
222  /** Registers a single function as available to modules which
223  * depend on this one. This routine converts the ExternSignature
224  * info into an LLVM type, which allows type safe linkage of
225  * external routines. */
226  void add_extern_for_export(const std::string &name,
227  const ExternCFunction &extern_c_function);
228 
229  /** Look up a symbol by name in this module or its dependencies. */
230  Symbol find_symbol_by_name(const std::string &) const;
231 
232  /** Take an llvm module and compile it. The requested exports will
233  be available via the exports method. */
234  void compile_module(std::unique_ptr<llvm::Module> mod,
235  const std::string &function_name, const Target &target,
236  const std::vector<JITModule> &dependencies = std::vector<JITModule>(),
237  const std::vector<std::string> &requested_exports = std::vector<std::string>());
238 
239  /** See JITSharedRuntime::memoization_cache_set_size */
241 
242  /** See JITSharedRuntime::memoization_cache_evict */
243  void memoization_cache_evict(uint64_t eviction_key) const;
244 
245  /** See JITSharedRuntime::reuse_device_allocations */
246  void reuse_device_allocations(bool) const;
247 
248  /** Return true if compile_module has been called on this module. */
249  bool compiled() const;
250 };
251 
253 public:
254  // Note only the first llvm::Module passed in here is used. The same shared runtime is used for all JIT.
255  static std::vector<JITModule> get(llvm::Module *m, const Target &target, bool create = true);
256  static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers);
258 
259  /** Set the maximum number of bytes used by memoization caching.
260  * If you are compiling statically, you should include HalideRuntime.h
261  * and call halide_memoization_cache_set_size() instead.
262  */
264 
265  /** Evict all cache entries that were tagged with the given
266  * eviction_key in the memoize scheduling directive. If you are
267  * compiling statically, you should include HalideRuntime.h and
268  * call halide_memoization_cache_evict() instead.
269  */
270  static void memoization_cache_evict(uint64_t eviction_key);
271 
272  /** Set whether or not Halide may hold onto and reuse device
273  * allocations to avoid calling expensive device API allocation
274  * functions. If you are compiling statically, you should include
275  * HalideRuntime.h and call halide_reuse_device_allocations
276  * instead. */
277  static void reuse_device_allocations(bool);
278 
279  static void release_all();
280 };
281 
/** Look up the address of the symbol named s.
 * NOTE(review): lookup scope and the not-found result are not visible
 * from this header — confirm against the implementation. */
void *get_symbol_address(const char *s);
283 
284 } // namespace Internal
285 } // namespace Halide
286 
287 #endif
This file declares the routines used by Halide internally in its runtime.
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
static void memoization_cache_evict(uint64_t eviction_key)
Evict all cache entries that were tagged with the given eviction_key in the memoize scheduling direct...
static void memoization_cache_set_size(int64_t size)
Set the maximum number of bytes used by memoization caching.
static JITHandlers set_default_handlers(const JITHandlers &handlers)
static std::vector< JITModule > get(llvm::Module *m, const Target &target, bool create=true)
static void reuse_device_allocations(bool)
Set whether or not Halide may hold onto and reuse device allocations to avoid calling expensive devic...
static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers)
A halide module.
Definition: Module.h:172
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1066
void * get_symbol_address(const char *s)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
__SIZE_TYPE__ size_t
void memoization_cache_evict(uint64_t eviction_key) const
See JITSharedRuntime::memoization_cache_evict.
void memoization_cache_set_size(int64_t size) const
See JITSharedRuntime::memoization_cache_set_size.
int(* argv_wrapper)(const void **args)
A slightly more type-safe wrapper around the raw halide module.
Definition: JITModule.h:209
void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol)
Registers a single Symbol as available to modules which depend on this one.
void compile_module(std::unique_ptr< llvm::Module > mod, const std::string &function_name, const Target &target, const std::vector< JITModule > &dependencies=std::vector< JITModule >(), const std::vector< std::string > &requested_exports=std::vector< std::string >())
Take an llvm module and compile it.
void add_extern_for_export(const std::string &name, const ExternCFunction &extern_c_function)
Registers a single function as available to modules which depend on this one.
void reuse_device_allocations(bool) const
See JITSharedRuntime::reuse_device_allocations.
void add_dependency(JITModule &dep)
Add another JITModule to the dependency chain.
Symbol find_symbol_by_name(const std::string &) const
Look up a symbol by name in this module or its dependencies.
static JITModule make_trampolines_module(const Target &target, const std::map< std::string, JITExtern > &externs, const std::string &suffix, const std::vector< JITModule > &deps)
Take a list of JITExterns and generate trampoline functions which can be called dynamically via a fun...
Symbol argv_entrypoint_symbol() const
Returns the Symbol structure for the argv wrapper routine corresponding to the entrypoint.
bool compiled() const
Return true if compile_module has been called on this module.
const std::map< std::string, Symbol > & exports() const
The exports map of a JITModule contains all symbols which are available to other JITModules which dep...
Symbol entrypoint_symbol() const
Returns the Symbol structure for the routine documented in main_function.
void * main_function() const
A pointer to the raw halide function.
argv_wrapper argv_function() const
IntrusivePtr< JITModuleContents > jit_module
Definition: JITModule.h:145
JITModule(const Module &m, const LoweredFunc &fn, const std::vector< JITModule > &dependencies=std::vector< JITModule >())
Definition of a lowered function.
Definition: Module.h:133
A set of custom overrides of runtime functions.
Definition: JITModule.h:33
int(* custom_do_par_for)(JITUserContext *, int(*)(JITUserContext *, int, uint8_t *), int, int, uint8_t *)
A custom parallel for loop launcher.
Definition: JITModule.h:83
int32_t(* custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create)
A custom method for the Halide runtime to acquire a CUDA context.
Definition: JITModule.h:115
int32_t(* custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr)
A custom method for the Halide runtime to acquire a cuda stream to use.
Definition: JITModule.h:124
void(* custom_error)(JITUserContext *, const char *)
The error handler function to be called in the case of runtime errors during Halide pipelines.
Definition: JITModule.h:87
int32_t(* custom_cuda_release_context)(JITUserContext *user_context)
The Halide runtime calls this when it is done with a cuda context.
Definition: JITModule.h:119
void(* custom_free)(JITUserContext *, void *)
Definition: JITModule.h:43
int32_t(* custom_trace)(JITUserContext *, const halide_trace_event_t *)
A custom routine to call when tracing is enabled.
Definition: JITModule.h:93
int(* custom_do_task)(JITUserContext *, int(*)(JITUserContext *, int, uint8_t *), int, uint8_t *)
A custom task handler to be called by the parallel for loop.
Definition: JITModule.h:61
void(* custom_print)(JITUserContext *, const char *)
Set the function called to print messages from the runtime.
Definition: JITModule.h:35
A context to be passed to Pipeline::realize.
Definition: JITModule.h:134
JITHandlers handlers
Definition: JITModule.h:136
Internal::JITErrorBuffer * error_buffer
Definition: JITModule.h:135
A struct representing a target machine and os to generate code for.
Definition: Target.h:19