Open3D (C++ API)  0.15.1
TorchHelper.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
27#pragma once
28#include <torch/script.h>
29
30#include <sstream>
31#include <type_traits>
32
34
// Macros for checking tensor properties.
//
// Each macro expands to a single statement (the do { } while (0) wrapper
// makes it safe after an unbraced `if`) that fails via TORCH_CHECK with a
// message naming the checked expression. The argument is parenthesized in
// the expansion so that expressions such as `*ptr` are checked as a whole.
#define CHECK_CUDA(x)                                           \
    do {                                                        \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor") \
    } while (0)
40
// Fails via TORCH_CHECK unless tensor expression `x` is contiguous.
// `x` is parenthesized so compound expressions (e.g. `*ptr`) work.
#define CHECK_CONTIGUOUS(x)                                        \
    do {                                                           \
        TORCH_CHECK((x).is_contiguous(), #x " must be contiguous") \
    } while (0)
45
// Fails via TORCH_CHECK unless tensor expression `x` has dtype torch::<type>.
// `type` is the bare name inside namespace torch, e.g. CHECK_TYPE(t, kFloat32).
// `x` is parenthesized so compound expressions (e.g. `*ptr`) work.
#define CHECK_TYPE(x, type)                                                  \
    do {                                                                     \
        TORCH_CHECK((x).dtype() == torch::type, #x " must have type " #type) \
    } while (0)
50
// Fails via TORCH_CHECK when the given tensors do not all share one device
// type (as decided by SameDeviceType). The error message lists the argument
// expressions followed by TensorInfoStr() of the tensors.
#define CHECK_SAME_DEVICE_TYPE(...)                                      \
    do {                                                                 \
        const bool same_device_type_ok_ = SameDeviceType({__VA_ARGS__}); \
        if (!same_device_type_ok_) {                                     \
            TORCH_CHECK(false,                                           \
                        #__VA_ARGS__                                     \
                        " must all have the same device type but got " + \
                                TensorInfoStr({__VA_ARGS__}))            \
        }                                                                \
    } while (0)
61
// Fails via TORCH_CHECK when the given tensors do not all share one dtype
// (as decided by SameDtype). The error message lists the argument
// expressions followed by TensorInfoStr() of the tensors.
#define CHECK_SAME_DTYPE(...)                                      \
    do {                                                           \
        const bool same_dtype_ok_ = SameDtype({__VA_ARGS__});      \
        if (!same_dtype_ok_) {                                     \
            TORCH_CHECK(false,                                     \
                        #__VA_ARGS__                               \
                        " must all have the same dtype but got " + \
                                TensorInfoStr({__VA_ARGS__}))      \
        }                                                          \
    } while (0)
69 } \
70 } while (0)
71
72// Conversion from standard types to torch types
73typedef std::remove_const<decltype(torch::kInt32)>::type TorchDtype_t;
74template <class T>
76 TORCH_CHECK(false, "Unsupported type");
77}
78template <>
80 return torch::kUInt8;
81}
82template <>
84 return torch::kInt8;
85}
86template <>
88 return torch::kInt16;
89}
90template <>
92 return torch::kInt32;
93}
94template <>
96 return torch::kInt64;
97}
98template <>
100 return torch::kFloat32;
101}
102template <>
104 return torch::kFloat64;
105}
106
107// convenience function for comparing standard types with torch types
108template <class T, class TDtype>
109inline bool CompareTorchDtype(const TDtype& t) {
110 return ToTorchDtype<T>() == t;
111}
112
113// convenience function to check if all tensors have the same device type
114inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors) {
115 if (tensors.size()) {
116 auto device_type = tensors.begin()->device().type();
117 for (auto t : tensors) {
118 if (device_type != t.device().type()) {
119 return false;
120 }
121 }
122 }
123 return true;
124}
125
126// convenience function to check if all tensors have the same dtype
127inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
128 if (tensors.size()) {
129 auto dtype = tensors.begin()->dtype();
130 for (auto t : tensors) {
131 if (dtype != t.dtype()) {
132 return false;
133 }
134 }
135 }
136 return true;
137}
138
139inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
140 std::stringstream sstr;
141 size_t count = 0;
142 for (const auto t : tensors) {
143 sstr << t.sizes() << " " << t.toString() << " " << t.device();
144 ++count;
145 if (count < tensors.size()) sstr << ", ";
146 }
147 return sstr.str();
148}
149
150// convenience function for creating a tensor for temp memory
151inline torch::Tensor CreateTempTensor(const int64_t size,
152 const torch::Device& device,
153 void** ptr = nullptr) {
154 torch::Tensor tensor = torch::empty(
155 {size}, torch::dtype(ToTorchDtype<uint8_t>()).device(device));
156 if (ptr) {
157 *ptr = tensor.data_ptr<uint8_t>();
158 }
159 return tensor;
160}
161
162inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
163 torch::Tensor tensor) {
164 using namespace open3d::ml::op_util;
165
166 std::vector<DimValue> shape;
167 const int rank = tensor.dim();
168 for (int i = 0; i < rank; ++i) {
169 shape.push_back(tensor.size(i));
170 }
171 return shape;
172}
173
175 class TDimX,
176 class... TArgs>
// Shape check for a torch::Tensor.
//
// Converts the tensor's sizes into a DimValue vector with GetShapeVector()
// and forwards it, together with the expected dimensions (dimex, args...),
// to open3d::ml::op_util::CheckShape<Opt>. Returns that function's
// (success, error message) tuple unchanged; the CHECK_SHAPE* macros turn a
// failure into a TORCH_CHECK error.
//
// NOTE(review): the template-head line that declares `Opt` is not part of
// this view. CHECK_SHAPE calls CheckShape(tensor, ...) without a template
// argument, so Opt presumably has a default value; confirm against the full
// header / ShapeChecking.h.
177std::tuple<bool, std::string> CheckShape(torch::Tensor tensor,
178 TDimX&& dimex,
179 TArgs&&... args) {
180 return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
181 std::forward<TDimX>(dimex),
182 std::forward<TArgs>(args)...);
183}
184
185//
186// Macros for checking the shape of Tensors.
187// Usage:
188// {
189// using namespace open3d::ml::op_util;
190// Dim w("w");
191// Dim h("h");
192// CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10
193// // and assigns w and h based on
194// // the shape of tensor1
195//
196// CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the last dim
197// // of tensor2 matches the last dim
198// // of tensor1. The first two dims
199// // must match 10, 20.
200// }
201//
202//
203// See "../ShapeChecking.h" for more info and limitations.
204//
// Fails via TORCH_CHECK when `tensor` does not match the expected dimensions
// given in the variadic arguments (checked by CheckShape with its default
// option). The message names the tensor expression and appends the error
// text returned by CheckShape.
#define CHECK_SHAPE(tensor, ...)                                      \
    do {                                                              \
        const auto cs_result_ = CheckShape(tensor, __VA_ARGS__);      \
        const bool cs_success_ = std::get<0>(cs_result_);             \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);      \
        TORCH_CHECK(cs_success_,                                      \
                    "invalid shape for '" #tensor "', " + cs_errstr_) \
    } while (0)
213
// Same as CHECK_SHAPE but runs CheckShape with CSOpt::COMBINE_FIRST_DIMS
// (see ShapeChecking.h for the option's semantics). The option is fully
// qualified so the macro does not depend on a
// `using namespace open3d::ml::op_util;` at the call site.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)                         \
    do {                                                                    \
        bool cs_success_;                                                   \
        std::string cs_errstr_;                                             \
        std::tie(cs_success_, cs_errstr_) =                                 \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>( \
                        tensor, __VA_ARGS__);                               \
        TORCH_CHECK(cs_success_,                                            \
                    "invalid shape for '" #tensor "', " + cs_errstr_)       \
    } while (0)
223
// Same as CHECK_SHAPE but runs CheckShape with CSOpt::IGNORE_FIRST_DIMS
// (see ShapeChecking.h for the option's semantics). The option is fully
// qualified so the macro does not depend on a
// `using namespace open3d::ml::op_util;` at the call site.
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)                         \
    do {                                                                   \
        bool cs_success_;                                                  \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) =                                \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>( \
                        tensor, __VA_ARGS__);                              \
        TORCH_CHECK(cs_success_,                                           \
                    "invalid shape for '" #tensor "', " + cs_errstr_)      \
    } while (0)
233
// Same as CHECK_SHAPE but runs CheckShape with CSOpt::COMBINE_LAST_DIMS
// (see ShapeChecking.h for the option's semantics). The option is fully
// qualified so the macro does not depend on a
// `using namespace open3d::ml::op_util;` at the call site.
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)                         \
    do {                                                                   \
        bool cs_success_;                                                  \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) =                                \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>( \
                        tensor, __VA_ARGS__);                              \
        TORCH_CHECK(cs_success_,                                           \
                    "invalid shape for '" #tensor "', " + cs_errstr_)      \
    } while (0)
243
// Same as CHECK_SHAPE but runs CheckShape with CSOpt::IGNORE_LAST_DIMS
// (see ShapeChecking.h for the option's semantics). The option is fully
// qualified so the macro does not depend on a
// `using namespace open3d::ml::op_util;` at the call site.
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)                         \
    do {                                                                  \
        bool cs_success_;                                                 \
        std::string cs_errstr_;                                           \
        std::tie(cs_success_, cs_errstr_) =                               \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>( \
                        tensor, __VA_ARGS__);                             \
        TORCH_CHECK(cs_success_,                                          \
                    "invalid shape for '" #tensor "', " + cs_errstr_)     \
    } while (0)
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:95
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:79
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:139
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:162
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:87
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:83
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:103
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:127
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:114
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:73
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:75
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:151
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:177
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:91
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:109
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:99
int size
Definition: FilePCD.cpp:59
int count
Definition: FilePCD.cpp:61
char type
Definition: FilePCD.cpp:60
Definition: ShapeChecking.h:35
CSOpt
Check shape options.
Definition: ShapeChecking.h:424