Kokkos Core Kernels Package Version of the Day
Loading...
Searching...
No Matches
Kokkos_Vector.hpp
1//@HEADER
2// ************************************************************************
3//
4// Kokkos v. 4.0
5// Copyright (2022) National Technology & Engineering
6// Solutions of Sandia, LLC (NTESS).
7//
8// Under the terms of Contract DE-NA0003525 with NTESS,
9// the U.S. Government retains certain rights in this software.
10//
11// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12// See https://kokkos.org/LICENSE for license information.
13// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14//
15//@HEADER
16
17#ifndef KOKKOS_VECTOR_HPP
18#define KOKKOS_VECTOR_HPP
19#ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20#define KOKKOS_IMPL_PUBLIC_INCLUDE
21#define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_VECTOR
22#endif
23
24#include <Kokkos_Macros.hpp>
25
26#if defined(KOKKOS_ENABLE_DEPRECATED_CODE_4)
27#if defined(KOKKOS_ENABLE_DEPRECATION_WARNINGS)
28namespace {
29[[deprecated("Deprecated <Kokkos_Vector.hpp> header is included")]] int
30emit_warning_kokkos_vector_deprecated() {
31 return 0;
32}
33static auto do_not_include = emit_warning_kokkos_vector_deprecated();
34} // namespace
35#endif
36#else
37#error "Deprecated <Kokkos_Vector.hpp> header is included"
38#endif
39
40#include <Kokkos_Core_fwd.hpp>
41#include <Kokkos_DualView.hpp>
42
/* Drop-in replacement for std::vector, based on Kokkos::DualView.
 * Most member functions only work on the host: calling them from a device
 * kernel will not compile.
 */
namespace Kokkos {

#ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4
// Deprecated std::vector-like container built on a rank-1, LayoutLeft
// Kokkos::DualView.
//
// size() is the number of logical elements and is tracked separately from
// span(), the allocated extent of the underlying DualView. Growing operations
// allocate roughly n * (over-allocation factor) elements; the factor starts
// at 1.1 and can be changed with set_overallocation(). Element access,
// iterators and the search helpers all operate on the host view.
template <class Scalar, class Arg1Type = void>
class KOKKOS_DEPRECATED vector
    : public DualView<Scalar*, LayoutLeft, Arg1Type> {
 public:
  using value_type      = Scalar;
  using pointer         = Scalar*;
  using const_pointer   = const Scalar*;
  using reference       = Scalar&;
  using const_reference = const Scalar&;
  using iterator        = Scalar*;
  using const_iterator  = const Scalar*;
  using size_type       = size_t;

 private:
  size_t _size         = 0;    // number of logical elements (kept <= span())
  float _extra_storage = 1.1;  // over-allocation factor applied on growth
  using DV             = DualView<Scalar*, LayoutLeft, Arg1Type>;

  // Extent to allocate so that at least n elements fit, including the
  // over-allocation factor. The product is evaluated in float, so for
  // n > 2^24 with a factor of 1.0 it can truncate below n, and a negative
  // argument to set_overallocation() makes the factor < 1. Clamping to n
  // guarantees the logical size never exceeds the allocation.
  size_t impl_padded_extent(size_t n) const {
    const size_t padded = size_t(n * _extra_storage);
    return padded < n ? n : padded;
  }

 public:
  // Element access; always reads the host view (no sync is performed).
#ifdef KOKKOS_ENABLE_CUDA_UVM
  KOKKOS_INLINE_FUNCTION reference operator()(int i) const {
    return DV::view_host()(i);
  }
  KOKKOS_INLINE_FUNCTION reference operator[](int i) const {
    return DV::view_host()(i);
  }
#else
  inline reference operator()(int i) const { return DV::view_host()(i); }
  inline reference operator[](int i) const { return DV::view_host()(i); }
#endif

  /* Member functions which behave like std::vector functions */

  vector() : DV() {}

  // Creates n elements equal to val, allocating n * 1.1 elements up front.
  vector(int n, Scalar val = Scalar())
      : DV("Vector", size_t(n * (1.1))), _size(n) {
    // Presumably flags the host copy as most recently modified so that
    // assign() takes its host branch — confirm against DualView.
    DV::modified_flags(0) = 1;

    assign(n, val);
  }

  // Sets the logical size to n, growing the allocation (with
  // over-allocation) when n reaches the current span. New elements are not
  // initialized.
  void resize(size_t n) {
    if (n >= span()) DV::resize(impl_padded_extent(n));
    _size = n;
  }

  // Unlike std::vector, this overwrites ALL n elements with val (it forwards
  // to assign()).
  void resize(size_t n, const Scalar& val) { assign(n, val); }

  // Sets the logical size to n and fills every element with val.
  void assign(size_t n, const Scalar& val) {
    /* Resize if necessary (behavior of std::vector) */
    if (n > span()) DV::resize(impl_padded_extent(n));
    _size = n;

    /* Fill on the host if the device copy needs a sync (host is newer),
     * otherwise fill on the device; then mark that side as modified. */
    if (DV::template need_sync<typename DV::t_dev::device_type>()) {
      set_functor_host f(DV::view_host(), val);
      parallel_for("Kokkos::vector::assign", n, f);
      typename DV::t_host::execution_space().fence(
          "Kokkos::vector::assign: fence after assigning values");
      DV::template modify<typename DV::t_host::device_type>();
    } else {
      set_functor f(DV::view_device(), val);
      parallel_for("Kokkos::vector::assign", n, f);
      typename DV::t_dev::execution_space().fence(
          "Kokkos::vector::assign: fence after assigning values");
      DV::template modify<typename DV::t_dev::device_type>();
    }
  }

  // Ensures room for at least n elements. Like std::vector::reserve, this
  // never shrinks the allocation (shrinking could drop it below size()).
  void reserve(size_t n) {
    if (n > span()) DV::resize(impl_padded_extent(n));
  }

  // Appends val on the host, growing the allocation when it is full.
  void push_back(Scalar val) {
    if (_size >= span()) {
      size_t new_size = _size * _extra_storage;
      // Always make room for at least one more element: the factor may
      // truncate to no growth for small sizes, or be < 1.
      if (new_size <= _size) new_size = _size + 1;
      DV::resize(new_size);
    }

    DV::sync_host();
    DV::view_host()(_size) = val;
    _size++;
    DV::modify_host();
  }

  // Precondition: !empty() (decrementing an empty size wraps around).
  void pop_back() { _size--; }

  // Resets the logical size; the allocation is kept.
  void clear() { _size = 0; }

  // Inserts val before it; returns an iterator to the inserted element.
  iterator insert(iterator it, const value_type& val) {
    return insert(it, 1, val);
  }

  // Inserts count copies of val before it, working on the host. Aborts if
  // it lies outside [begin(), end()]. Returns an iterator to the first
  // inserted element (it itself when count == 0).
  iterator insert(iterator it, size_type count, const value_type& val) {
    if ((size() == 0) && (it == begin())) {
      // Empty vector: fill through assign() and bring the data to the host.
      resize(count, val);
      DV::sync_host();
      return begin();
    }
    DV::sync_host();
    DV::modify_host();
    if (std::less<>()(it, begin()) || std::less<>()(end(), it))
      Kokkos::abort("Kokkos::vector::insert : invalid insert iterator");
    if (count == 0) return it;
    ptrdiff_t start = std::distance(begin(), it);
    auto org_size   = size();
    // Note: resize(...) may reallocate and invalidate it; use begin() + start
    resize(size() + count);

    std::copy_backward(begin() + start, begin() + org_size,
                       begin() + org_size + count);
    std::fill_n(begin() + start, count, val);

    return begin() + start;
  }

 private:
  template <class T>
  struct impl_is_input_iterator : /* TODO replace this */ std::bool_constant<
                                      !std::is_convertible_v<T, size_type>> {};

 public:
  // Inserts the range [b, e) before it, working on the host. Aborts if it
  // lies outside [begin(), end()]. Returns an iterator to the first
  // inserted element.
  // TODO: can use detection idiom to generate better error message here later
  template <typename InputIterator>
  std::enable_if_t<impl_is_input_iterator<InputIterator>::value, iterator>
  insert(iterator it, InputIterator b, InputIterator e) {
    ptrdiff_t count = std::distance(b, e);

    DV::sync_host();
    DV::modify_host();
    if (std::less<>()(it, begin()) || std::less<>()(end(), it))
      Kokkos::abort("Kokkos::vector::insert : invalid insert iterator");

    ptrdiff_t start = std::distance(begin(), it);
    auto org_size   = size();

    // Note: resize(...) invalidates it; use begin() + start instead
    resize(size() + count);

    std::copy_backward(begin() + start, begin() + org_size,
                       begin() + org_size + count);
    std::copy(b, e, begin() + start);

    return begin() + start;
  }

  KOKKOS_INLINE_FUNCTION constexpr bool is_allocated() const {
    return DV::is_allocated();
  }

  size_type size() const { return _size; }
  size_type max_size() const { return 2000000000; }
  // Allocated extent of the underlying DualView (the "capacity").
  size_type span() const { return DV::span(); }
  bool empty() const { return _size == 0; }

  // Pointers and iterators all refer to the host view.
  pointer data() const { return DV::view_host().data(); }

  iterator begin() const { return DV::view_host().data(); }

  const_iterator cbegin() const { return DV::view_host().data(); }

  // Adding 0 to a null data() pointer is well-defined, so no special case
  // is needed for an empty vector.
  iterator end() const { return DV::view_host().data() + _size; }

  const_iterator cend() const { return DV::view_host().data() + _size; }

  reference front() { return DV::view_host()(0); }

  reference back() { return DV::view_host()(_size - 1); }

  const_reference front() const { return DV::view_host()(0); }

  const_reference back() const { return DV::view_host()(_size - 1); }

  /* std::algorithms which work originally with iterators, here they are
   * implemented as member functions */

  // Binary search for comp_val in the host data between start and an upper
  // index that is theEnd when _size > theEnd and _size - 1 otherwise (used
  // as an inclusive bound). Returns theEnd when that range is empty. If the
  // midpoint value falls outside [data[lower], data[upper]] (unsorted data)
  // it returns upper or start immediately. When comp_val exceeds every
  // element, the result is the last index searched, not one past it.
  // NOTE(review): theEnd is inclusive when _size > theEnd — confirm intent.
  size_t lower_bound(const size_t& start, const size_t& theEnd,
                     const Scalar& comp_val) const {
    int lower = start;  // FIXME (mfh 24 Apr 2014) narrowing conversion
    int upper =
        _size > theEnd
            ? theEnd
            : _size - 1;  // FIXME (mfh 24 Apr 2014) narrowing conversion
    if (upper <= lower) {
      return theEnd;
    }

    Scalar lower_val = DV::view_host()(lower);
    Scalar upper_val = DV::view_host()(upper);
    size_t idx       = (upper + lower) / 2;
    Scalar val       = DV::view_host()(idx);
    if (val > upper_val) return upper;
    if (val < lower_val) return start;

    while (upper > lower) {
      if (comp_val > val) {
        lower = ++idx;
      } else {
        upper = idx;
      }
      idx = (upper + lower) / 2;
      val = DV::view_host()(idx);
    }
    return idx;
  }

  // True when the host data is in non-decreasing order. Iterating from 1
  // avoids the unsigned underflow of _size - 1 on an empty vector.
  bool is_sorted() {
    for (size_t i = 1; i < _size; ++i) {
      if (DV::view_host()(i - 1) > DV::view_host()(i)) return false;
    }
    return true;
  }

  // Binary search for val in the (assumed ascending) host data. Returns a
  // pointer to the first equal element, or end() if val is not present.
  iterator find(Scalar val) const {
    if (_size == 0) return end();

    if ((val < DV::view_host()(0)) || (val > DV::view_host()(_size - 1)))
      return end();

    size_t lower   = 0;
    size_t upper   = _size - 1;
    size_t current = _size / 2;

    while (upper > lower) {
      if (val > DV::view_host()(current))
        lower = current + 1;
      else
        upper = current;
      current = (upper + lower) / 2;
    }

    if (val == DV::view_host()(current))
      return &DV::view_host()(current);
    else
      return end();
  }

  /* Additional functions for data management */

  // Unconditional whole-view copies between the host and device views.
  void device_to_host() { deep_copy(DV::view_host(), DV::view_device()); }
  void host_to_device() const { deep_copy(DV::view_device(), DV::view_host()); }

  void on_host() { DV::template modify<typename DV::t_host::device_type>(); }
  void on_device() { DV::template modify<typename DV::t_dev::device_type>(); }

  // Sets the over-allocation factor to 1 + extra (e.g. 0.1 -> 10% slack).
  // Growth operations still allocate at least the requested size.
  void set_overallocation(float extra) { _extra_storage = 1.0 + extra; }

 public:
  // Fills a device view with a single value (used by assign()).
  struct set_functor {
    using execution_space = typename DV::t_dev::execution_space;
    typename DV::t_dev _data;
    Scalar _val;

    set_functor(typename DV::t_dev data, Scalar val) : _data(data), _val(val) {}

    KOKKOS_INLINE_FUNCTION
    void operator()(const int& i) const { _data(i) = _val; }
  };

  // Fills a host view with a single value (used by assign()).
  struct set_functor_host {
    using execution_space = typename DV::t_host::execution_space;
    typename DV::t_host _data;
    Scalar _val;

    set_functor_host(typename DV::t_host data, Scalar val)
        : _data(data), _val(val) {}

    KOKKOS_INLINE_FUNCTION
    void operator()(const int& i) const { _data(i) = _val; }
  };
};
#endif

}  // namespace Kokkos
336#ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_VECTOR
337#undef KOKKOS_IMPL_PUBLIC_INCLUDE
338#undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_VECTOR
339#endif
340#endif
Declaration and definition of Kokkos::DualView.