// Dict_inl.h — inline and template definitions for c10::Dict (forked from pytorch/pytorch).
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/util/hash.h>
namespace c10 {
namespace detail {
/// Key-equality predicate used by the map that backs c10::Dict.
///
/// When both keys are tensors they are matched by identity only, the same
/// way Python matches tensor dict keys. Any other pair of keys goes through
/// _fastEqualsForContainer, which first compares by identity for efficiency
/// and then by value (see: [container equality]).
inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
  const bool compareByIdentityOnly = lhs.isTensor() && rhs.isTensor();
  return compareByIdentityOnly ? lhs.is(rhs)
                               : _fastEqualsForContainer(lhs, rhs);
}
} // namespace detail
// Forward declarations so the templates below can use these helpers; their
// definitions are not part of this section. getTypePtr<T>() supplies the
// runtime type matching a static C++ type (used for Dict element types and
// cast checks); toString renders a Type for the cast-mismatch messages.
template<class T> decltype(auto) getTypePtr();
std::string toString(const Type& type);
namespace impl {
/// Converts a type-erased GenericDict into a Dict<Key, Value>.
///
/// The runtime key and value types stored in `dict` must equal
/// getTypePtr<Key>() and getTypePtr<Value>(). Otherwise
/// TORCH_INTERNAL_ASSERT fails with a message naming both the source and the
/// target Dict types. On success the DictImpl pointer is moved into the
/// result, so the returned Dict refers to the same DictImpl that `dict` did
/// and no entries are copied.
template<class Key, class Value>
Dict<Key, Value> toTypedDict(GenericDict dict) {
  const auto& storedKeyType = dict.impl_->elementTypes.keyType;
  const auto& storedValueType = dict.impl_->elementTypes.valueType;

  TORCH_INTERNAL_ASSERT(
      *getTypePtr<Key>() == *storedKeyType,
      "Tried to cast a Dict<", toString(*storedKeyType), ", ", toString(*storedValueType),
      "> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()),
      ">. Key types mismatch.");
  TORCH_INTERNAL_ASSERT(
      *getTypePtr<Value>() == *storedValueType,
      "Tried to cast a Dict<", toString(*storedKeyType), ", ", toString(*storedValueType),
      "> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()),
      ">. Value types mismatch.");

  // The aliases above point into *dict.impl_ and are not used past this point.
  return Dict<Key, Value>(std::move(dict.impl_));
}

/// Type-erases a Dict<Key, Value> into a GenericDict. The DictImpl pointer is
/// moved over, so the result refers to the same DictImpl and nothing is copied.
template<class Key, class Value>
GenericDict toGenericDict(Dict<Key, Value> dict) {
  return GenericDict(std::move(dict.impl_));
}
} // namespace impl
namespace detail {
/// Hash functor used by the map that backs c10::Dict.
///
/// Supported key tags are Int, String, Double, ComplexDouble, Bool, Tensor and
/// Device. A tensor is hashed by its TensorImpl pointer, in line with
/// DictKeyEqualTo, which compares tensor keys by identity. Any other tag
/// throws std::runtime_error naming the offending tag.
inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
  if (ivalue.isInt()) {
    return std::hash<int64_t>{}(ivalue.toInt());
  }
  if (ivalue.isString()) {
    return std::hash<c10::string_view>{}(ivalue.toStringView());
  }
  if (ivalue.isDouble()) {
    return std::hash<double>{}(ivalue.toDouble());
  }
  if (ivalue.isComplexDouble()) {
    return c10::hash<c10::complex<double>>{}(ivalue.toComplexDouble());
  }
  if (ivalue.isBool()) {
    return std::hash<bool>{}(ivalue.toBool());
  }
  if (ivalue.isTensor()) {
    return std::hash<TensorImpl*>{}(ivalue.toTensor().unsafeGetTensorImpl());
  }
  if (ivalue.isDevice()) {
    return std::hash<Device>{}(ivalue.toDevice());
  }
  // No supported tag matched.
  throw std::runtime_error(
      "Can't hash IValues with tag '" + ivalue.tagKind() + "'");
}

/// Returns a freshly allocated DictImpl constructed from copies of this
/// impl's `dict` and `elementTypes`.
inline intrusive_ptr<DictImpl> DictImpl::copy() const {
  return make_intrusive<DictImpl>(this->dict, this->elementTypes);
}
} // namespace detail
/// Creates an empty Dict whose runtime key/value types come from
/// getTypePtr<Key>() and getTypePtr<Value>(). Instantiating it with IValue
/// as the key or value type is rejected at compile time. Such dicts must use
/// the GenericDict(keyType, valueType) constructor instead.
template<class Key, class Value>
Dict<Key, Value>::Dict()
    : Dict(make_intrusive<detail::DictImpl>(
          detail::DictImpl::dict_map_type(),
          detail::DictImpl::DictElementTypes{
              getTypePtr<Key>(),
              getTypePtr<Value>()})) {
  static_assert(
      !std::is_same<Key, IValue>::value,
      "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
  static_assert(
      !std::is_same<Value, IValue>::value,
      "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
}
/// Creates an empty GenericDict (Dict<IValue, IValue>) whose key and value
/// types are supplied at runtime. Both types are moved into the new DictImpl.
/// Any other instantiation is rejected at compile time.
template<class Key, class Value>
Dict<Key, Value>::Dict(TypePtr keyType, TypePtr valueType)
    : Dict(make_intrusive<detail::DictImpl>(
          detail::DictImpl::dict_map_type(),
          detail::DictImpl::DictElementTypes{
              std::move(keyType),
              std::move(valueType)})) {
  static_assert(
      std::is_same<Key, IValue>::value,
      "This constructor is only valid for c10::impl::GenericDict.");
  static_assert(
      std::is_same<Value, IValue>::value,
      "This constructor is only valid for c10::impl::GenericDict.");
}
/// Wraps an existing DictImpl. The new Dict takes over the given pointer.
template<class Key, class Value>
Dict<Key, Value>::Dict(c10::intrusive_ptr<detail::DictImpl>&& impl)
    : impl_(std::move(impl)) {}

/// Returns a Dict backed by a new DictImpl produced by DictImpl::copy().
/// Inserting into or erasing from one of the two Dicts therefore does not
/// change the other's map.
template<class Key, class Value>
Dict<Key, Value> Dict<Key, Value>::copy() const {
  auto duplicatedImpl = impl_->copy();
  return Dict<Key, Value>(std::move(duplicatedImpl));
}
/// Iterator to the first entry of the underlying map.
template<class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() const {
  auto& entries = impl_->dict;
  return iterator{entries.begin()};
}

/// Past-the-end iterator of the underlying map.
template<class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::end() const {
  auto& entries = impl_->dict;
  return iterator{entries.end()};
}

/// True when the Dict holds no entries.
template<class Key, class Value>
bool Dict<Key, Value>::empty() const {
  const auto& entries = impl_->dict;
  return entries.empty();
}

/// Number of entries currently stored.
template<class Key, class Value>
typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
  const auto& entries = impl_->dict;
  return entries.size();
}

/// Removes every entry. The method is const because it mutates the shared
/// DictImpl, not the handle.
template<class Key, class Value>
void Dict<Key, Value>::clear() const {
  auto& entries = impl_->dict;
  entries.clear();
}
/// Inserts (key, value) through the underlying map's insert.
///
/// The arguments are first converted to the static Key and Value types and
/// then stored as an IValue pair. Returns an iterator to the resulting entry
/// together with the bool reported by the map's insert. By the usual map
/// convention that bool is true when a new entry was added.
template<class Key, class Value>
template<class Key_, class Value_>
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) const {
  static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert");
  static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert");

  // Braced initialization converts the key before the value.
  std::pair<IValue, IValue> entry{
      Key(std::forward<Key_>(key)),
      Value(std::forward<Value_>(value))};
  auto outcome = impl_->dict.insert(std::move(entry));
  return {iterator{outcome.first}, outcome.second};
}
/// Inserts (key, value), or overwrites the value when the key already exists,
/// through the underlying map's insert_or_assign.
///
/// The arguments are converted to the static Key and Value types first.
/// Returns an iterator to the affected entry and the bool reported by the
/// map. By the usual map convention that bool is true on insertion and false
/// on assignment.
template<class Key, class Value>
template<class Key_, class Value_>
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) const {
  static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert_or_assign");
  static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert_or_assign");

  auto& entries = impl_->dict;
  auto outcome = entries.insert_or_assign(
      Key(std::forward<Key_>(key)),
      Value(std::forward<Value_>(value)));
  return {iterator{outcome.first}, outcome.second};
}
/// Removes the entry that `iter` refers to. The underlying map iterator held
/// in `iter.entryRef_` is forwarded to the map's erase.
template<class Key, class Value>
void Dict<Key, Value>::erase(iterator iter) const {
  auto& entries = impl_->dict;
  entries.erase(iter.entryRef_.iterator_);
}

/// Removes the entry with the given key, if any. Returns the number of
/// entries removed, as reported by the underlying map.
template<class Key, class Value>
C10_NODISCARD size_t Dict<Key, Value>::erase(const Key& key) const {
  auto& entries = impl_->dict;
  return entries.erase(key);
}
/// Returns the value stored for `key`, converted to Value. Looks the key up
/// with the underlying map's at().
template<class Key, class Value>
Value Dict<Key, Value>::at(const Key& key) const {
  auto& stored = impl_->dict.at(key);
  return stored.template to<Value>();
}

/// Iterator to the entry for `key`, or end() when the key is absent.
template<class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) const {
  auto& entries = impl_->dict;
  return iterator{entries.find(key)};
}

/// True when an entry for `key` is present.
template<class Key, class Value>
bool Dict<Key, Value>::contains(const Key& key) const {
  return find(key) != end();
}

/// Forwards to the underlying map's reserve(count).
template<class Key, class Value>
void Dict<Key, Value>::reserve(size_type count) const {
  auto& entries = impl_->dict;
  entries.reserve(count);
}
/// Runtime key type recorded in this Dict's DictImpl.
template<class Key, class Value>
TypePtr Dict<Key, Value>::keyType() const {
  const auto& types = impl_->elementTypes;
  return types.keyType;
}

/// Runtime value type recorded in this Dict's DictImpl.
template<class Key, class Value>
TypePtr Dict<Key, Value>::valueType() const {
  const auto& types = impl_->elementTypes;
  return types.valueType;
}

/// Overwrites the recorded key type without checking it against the existing
/// entries. The change lands on the shared DictImpl, so every Dict referring
/// to that impl sees it.
template <class Key, class Value>
void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
  auto& types = impl_->elementTypes;
  types.keyType = std::move(t);
}

/// Overwrites the recorded value type without checking it against the
/// existing entries. The change lands on the shared DictImpl, so every Dict
/// referring to that impl sees it.
template <class Key, class Value>
void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
  auto& types = impl_->elementTypes;
  types.valueType = std::move(t);
}
/// Two Dicts are equal when they share the same DictImpl, or when their
/// DictImpls compare equal by value. The identity check short-circuits the
/// value comparison.
template <class Key_, class Value_>
bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
  return lhs.impl_ == rhs.impl_ || *lhs.impl_ == *rhs.impl_;
}

/// Negation of operator==.
template <class Key_, class Value_>
bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}

/// Identity check: true only when both Dicts refer to the same DictImpl.
template <class Key, class Value>
bool Dict<Key, Value>::is(const Dict& rhs) const {
  return rhs.impl_ == this->impl_;
}
}