Open3D (C++ API)  0.15.1
TrilinearDevoxelizeKernel.h
Go to the documentation of this file.
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// The MIT License (MIT)
//
// Copyright (c) 2018-2021 www.open3d.org
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
// ----------------------------------------------------------------------------
26
27#pragma once
28
29#include "../TensorFlowHelper.h"
30#include "tensorflow/core/framework/op.h"
31#include "tensorflow/core/framework/op_kernel.h"
32#include "tensorflow/core/lib/core/errors.h"
33
34class TrilinearDevoxelizeOpKernel : public tensorflow::OpKernel {
35public:
37 tensorflow::OpKernelConstruction* context)
38 : tensorflow::OpKernel(context) {
39 using namespace tensorflow;
40 OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
41 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training));
42 OP_REQUIRES(context, r > 0,
43 errors::InvalidArgument(
44 "TrilinearDevoxelize expects positive resolution"));
45 }
46
47 void Compute(tensorflow::OpKernelContext* context) override {
48 using namespace tensorflow;
49 const Tensor& coords = context->input(0);
50 OP_REQUIRES(
51 context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
52 errors::InvalidArgument("TrilinearDevoxelize expects "
53 "(batch_size, 3, N) coordinate shape"));
54 const Tensor& feat = context->input(1);
55 OP_REQUIRES(context, feat.dims() == 5,
56 errors::InvalidArgument("TrilinearDevoxelize expects "
57 "5 dimensions for features"));
58
59 int batch_size = coords.shape().dim_size(0);
60 int num_points = coords.shape().dim_size(2);
61 int feat_dim = feat.shape().dim_size(1);
62
63 auto coords_flat = coords.flat<float>();
64 auto feat_flat = feat.flat<float>();
65
66 const float* inp_coords = &(coords_flat(0));
67 const float* inp_feat = &(feat_flat(0));
68
69 Tensor* out_tensor_0;
70 OP_REQUIRES_OK(context,
71 context->allocate_output(
72 0, TensorShape{batch_size, feat_dim, num_points},
73 &out_tensor_0));
74 Tensor* out_tensor_1;
75 OP_REQUIRES_OK(context,
76 context->allocate_output(
77 1, TensorShape{batch_size, 8, num_points},
78 &out_tensor_1));
79 Tensor* out_tensor_2;
80 OP_REQUIRES_OK(context,
81 context->allocate_output(
82 2, TensorShape{batch_size, 8, num_points},
83 &out_tensor_2));
84 auto flat_0 = out_tensor_0->flat<float>();
85 auto flat_1 = out_tensor_1->flat<int>();
86 auto flat_2 = out_tensor_2->flat<float>();
87
88 float* out_0 = &(flat_0(0));
89 int* out_1 = &(flat_1(0));
90 float* out_2 = &(flat_2(0));
91
92 if (is_training)
93 Kernel(context, batch_size, feat_dim, num_points, r, r * r,
94 r * r * r, true, inp_coords, inp_feat, out_1, out_2, out_0);
95 else
96 Kernel(context, batch_size, feat_dim, num_points, r, r * r,
97 r * r * r, false, inp_coords, inp_feat, out_1, out_2, out_0);
98 }
99
100 virtual void Kernel(tensorflow::OpKernelContext* context,
101 int b,
102 int c,
103 int n,
104 int r,
105 int r2,
106 int r3,
107 bool training,
108 const float* coords,
109 const float* feat,
110 int* inds,
111 float* wgts,
112 float* outs) = 0;
113
114protected:
115 int r;
117};
118
119class TrilinearDevoxelizeGradOpKernel : public tensorflow::OpKernel {
120public:
122 tensorflow::OpKernelConstruction* context)
123 : tensorflow::OpKernel(context) {
124 using namespace tensorflow;
125 OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
126 OP_REQUIRES(
127 context, r > 0,
128 errors::InvalidArgument(
129 "TrilinearDevoxelizeGrad expects positive resolution"));
130 }
131
132 void Compute(tensorflow::OpKernelContext* context) override {
133 using namespace tensorflow;
134 const Tensor& grad_y = context->input(0);
135 OP_REQUIRES(
136 context, grad_y.dims() == 3,
137 errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
138 "(batch_size, C, N) gradient shape"));
139 const Tensor& inds = context->input(1);
140 OP_REQUIRES(
141 context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
142 errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
143 "(batch_size, 8, N) indices shape"));
144 const Tensor& wgts = context->input(2);
145 OP_REQUIRES(
146 context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
147 errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
148 "(batch_size, 8, N) weights shape"));
149
150 int batch_size = grad_y.shape().dim_size(0);
151 int num_points = grad_y.shape().dim_size(2);
152 int feat_dim = grad_y.shape().dim_size(1);
153
154 auto grad_y_flat = grad_y.flat<float>();
155 auto inds_flat = inds.flat<int>();
156 auto wgts_flat = wgts.flat<float>();
157
158 const float* inp_grad_y = &(grad_y_flat(0));
159 const int* inp_inds = &(inds_flat(0));
160 const float* inp_wgts = &(wgts_flat(0));
161
162 Tensor* out_tensor;
163 OP_REQUIRES_OK(context,
164 context->allocate_output(
165 0, TensorShape{batch_size, feat_dim, r, r, r},
166 &out_tensor));
167 auto flat_tensor = out_tensor->flat<float>();
168
169 float* out = &(flat_tensor(0));
170
171 Kernel(context, batch_size, feat_dim, num_points, r * r * r, inp_inds,
172 inp_wgts, inp_grad_y, out);
173 }
174
175 virtual void Kernel(tensorflow::OpKernelContext* context,
176 int b,
177 int c,
178 int n,
179 int r3,
180 const int* inds,
181 const float* wgts,
182 const float* grad_y,
183 float* grad_x) = 0;
184
185protected:
186 int r;
187};
ImGuiContext * context
Definition: Window.cpp:95
Definition: TrilinearDevoxelizeKernel.h:119
int r
Definition: TrilinearDevoxelizeKernel.h:186
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:132
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:121
Definition: TrilinearDevoxelizeKernel.h:34
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:36
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:47
int r
Definition: TrilinearDevoxelizeKernel.h:115
bool is_training
Definition: TrilinearDevoxelizeKernel.h:116