forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_tensor_str.py
316 lines (268 loc) · 12.6 KB
/
_tensor_str.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import math
import torch
from torch._six import inf
class __PrinterOptions(object):
    """Holder for the module-wide tensor printing options.

    The class attributes below are the initial values. `set_printoptions`
    changes them by assigning attributes on the shared `PRINT_OPTS` instance.
    """

    # --- layout -----------------------------------------------------------
    # Characters per line used when deciding where to insert line breaks.
    linewidth = 80
    # Element count above which a tensor is summarized instead of printed in full.
    threshold = 1000
    # Number of items shown at the beginning and end of each summarized dimension.
    edgeitems = 3

    # --- number formatting ------------------------------------------------
    # Digits of precision for floating point output.
    precision = 4
    # True/False forces scientific notation on/off; None leaves the choice
    # to `_Formatter`.
    sci_mode = None


# Single shared instance consulted by every printing helper in this module.
PRINT_OPTS = __PrinterOptions()
# Explicit keyword arguments (rather than **kwargs) keep the generated docs readable.
def set_printoptions(
        precision=None,
        threshold=None,
        edgeitems=None,
        linewidth=None,
        profile=None,
        sci_mode=None
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    A `profile`, when given, is applied first; any of the individual options
    that are not None are then applied on top of it. A profile name that is
    not recognised leaves the current values untouched.

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by `_Formatter`.
            Note that this option is assigned on every call, so omitting it
            resets a previously forced mode back to None.
    """
    option_names = ('precision', 'threshold', 'edgeitems', 'linewidth')

    # Each preset lists its values in the same order as `option_names`.
    presets = (
        ("default", (4, 1000, 3, 80)),
        ("short", (2, 1000, 2, 80)),
        ("full", (4, inf, 3, 80)),
    )

    if profile is not None:
        for preset_name, preset_values in presets:
            if profile == preset_name:
                for option, value in zip(option_names, preset_values):
                    setattr(PRINT_OPTS, option, value)
                break

    # Explicit arguments win over whatever the profile just installed.
    explicit_values = (precision, threshold, edgeitems, linewidth)
    for option, value in zip(option_names, explicit_values):
        if value is not None:
            setattr(PRINT_OPTS, option, value)

    PRINT_OPTS.sci_mode = sci_mode
class _Formatter(object):
    """Choose a notation and a common column width for a tensor's elements.

    The constructor inspects `tensor` once and records:

    * ``floating_dtype`` -- whether the tensor's dtype is floating point;
    * ``int_mode``       -- for floating tensors, whether every finite
      non-zero value is a whole number (printed as ``1.`` instead of
      ``1.0000``);
    * ``sci_mode``       -- whether scientific notation is used;
    * ``max_width``      -- the widest formatted element that was measured.

    `format` then renders single Python scalars right-aligned to that width.
    """

    def __init__(self, tensor):
        """Analyse `tensor` and settle on notation and width.

        Args:
            tensor: tensor whose elements will later be passed, one Python
                scalar at a time, to `format`.
        """
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

            if not self.floating_dtype:
                # Non-floating values are printed with plain '{}' formatting.
                for value in tensor_view.tolist():
                    self.max_width = max(self.max_width, len('{}'.format(value)))
                if PRINT_OPTS.sci_mode is not None:
                    self.sci_mode = PRINT_OPTS.sci_mode
                return

            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))

            if nonzero_finite_vals.numel() == 0:
                # No valid number to measure: keep the defaults set above.
                # The PRINT_OPTS.sci_mode override is not applied in this case.
                return

            # Convert to double for easy calculation. HalfTensor overflows with
            # 1e8, and there's no div() on CPU.
            nonzero_finite_abs = nonzero_finite_vals.abs().double()
            nonzero_finite_min = nonzero_finite_abs.min()
            nonzero_finite_max = nonzero_finite_abs.max()

            # int_mode holds only if every value equals its own ceiling.
            self.int_mode = not bool(
                torch.ceil(nonzero_finite_vals).ne(nonzero_finite_vals).any())

            # Automatic choice: a wide dynamic range or a very large magnitude
            # switches to scientific notation; for non-integral data a very
            # small magnitude does too.
            auto_sci_mode = bool(
                nonzero_finite_max / nonzero_finite_min > 1000.
                or nonzero_finite_max > 1.e8
                or (not self.int_mode and nonzero_finite_min < 1.e-4))

            # Resolve the final notation BEFORE measuring, so that a forced
            # mode is measured with the format that `format` will really use.
            if PRINT_OPTS.sci_mode is not None:
                self.sci_mode = PRINT_OPTS.sci_mode
            else:
                self.sci_mode = auto_sci_mode

            if self.sci_mode:
                template = ('{{:.{}e}}').format(PRINT_OPTS.precision)
                extra = 0
            elif self.int_mode:
                # In int_mode for floats all numbers are integers, and a
                # trailing '.' is appended to show the tensor is of floating
                # type; add 1 to the length to account for it.
                template = '{:.0f}'
                extra = 1
            else:
                template = ('{{:.{}f}}').format(PRINT_OPTS.precision)
                extra = 0

            for value in nonzero_finite_vals.tolist():
                self.max_width = max(self.max_width, len(template.format(value)) + extra)

    def width(self):
        """Return the column width measured by the constructor."""
        return self.max_width

    def format(self, value):
        """Render one Python scalar, left-padded with spaces to `max_width`.

        Args:
            value: a Python number (for example an item of ``tensor.tolist()``).

        Returns:
            The formatted string. Text longer than `max_width` is returned
            unpadded.
        """
        if self.floating_dtype:
            if self.sci_mode:
                ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value)
            elif self.int_mode:
                ret = '{:.0f}'.format(value)
                # inf/nan are printed without the trailing decimal point.
                if not (math.isinf(value) or math.isnan(value)):
                    ret += '.'
            else:
                ret = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
        else:
            ret = '{}'.format(value)
        return (self.max_width - len(ret)) * ' ' + ret
def _scalar_str(self, formatter):
    """Format a zero-dimensional tensor.

    The tensor's single value is extracted with ``item()`` and handed to
    ``formatter.format``; the resulting string is returned unchanged.
    """
    scalar_value = self.item()
    return formatter.format(scalar_value)
def _vector_str(self, indent, formatter, summarize):
    """Format a one-dimensional tensor as a bracketed, comma-separated list.

    Args:
        indent: column at which the opening bracket sits; continuation lines
            are indented by ``indent + 1`` spaces.
        formatter: object providing ``width()`` and ``format(value)``.
        summarize: when true and the vector has more than
            ``2 * PRINT_OPTS.edgeitems`` elements, only the first and last
            ``edgeitems`` elements are shown, separated by ``' ...'``.

    Returns:
        The formatted string, wrapped so that each line holds as many
        elements as fit in ``PRINT_OPTS.linewidth`` (at least one).
    """
    # length includes spaces and comma between elements
    element_length = formatter.width() + 2
    elements_per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))))

    edgeitems = PRINT_OPTS.edgeitems
    length = self.size(0)
    if summarize and length > 2 * edgeitems:
        # Slice the tail from an explicit start index: `self[-edgeitems:]`
        # would select the WHOLE vector when edgeitems is 0.
        head = self[:edgeitems].tolist()
        tail = self[length - edgeitems:].tolist()
        data = ([formatter.format(val) for val in head] +
                [' ...'] +
                [formatter.format(val) for val in tail])
    else:
        data = [formatter.format(val) for val in self.tolist()]

    data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)]
    lines = [', '.join(line) for line in data_lines]
    return '[' + (',' + '\n' + ' ' * (indent + 1)).join(lines) + ']'
def _tensor_str_with_formatter(self, indent, formatter, summarize):
    """Recursively format a tensor of any rank with an existing formatter.

    Rank 0 and rank 1 tensors are delegated to `_scalar_str` and
    `_vector_str`. For higher ranks each slice along the first dimension is
    formatted one level deeper (``indent + 1``) and the pieces are joined with
    a comma, ``dim - 1`` newlines and the indentation. When `summarize` is
    true and the first dimension is longer than ``2 * PRINT_OPTS.edgeitems``,
    only the leading and trailing ``edgeitems`` slices are rendered, with
    ``'...'`` between them.
    """
    ndim = self.dim()
    if ndim == 0:
        return _scalar_str(self, formatter)
    if ndim == 1:
        return _vector_str(self, indent, formatter, summarize)

    def render(index):
        # One sub-tensor, nested one column further to the right.
        return _tensor_str_with_formatter(self[index], indent + 1, formatter, summarize)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        edge = PRINT_OPTS.edgeitems
        total = len(self)
        pieces = [render(i) for i in range(0, edge)]
        pieces.append('...')
        pieces.extend([render(i) for i in range(total - edge, total)])
    else:
        pieces = [render(i) for i in range(0, self.size(0))]

    separator = ',' + '\n' * (ndim - 1) + ' ' * (indent + 1)
    return '[' + separator.join(pieces) + ']'
def _tensor_str(self, indent):
    """Return the bracketed element listing for `self` (without the prefix).

    Empty tensors yield ``'[]'``. Tensors with more elements than
    ``PRINT_OPTS.threshold`` are summarized, and in that case the formatter is
    built from the summarized data only. float16 and bfloat16 tensors are
    converted with ``float()`` before formatting.
    """
    element_count = self.numel()
    if element_count == 0:
        return '[]'

    summarize = element_count > PRINT_OPTS.threshold

    values = self
    if values.dtype is torch.float16 or values.dtype is torch.bfloat16:
        values = values.float()

    if summarize:
        formatter = _Formatter(get_summarized_data(values))
    else:
        formatter = _Formatter(values)
    return _tensor_str_with_formatter(values, indent, formatter, summarize)
def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    """Append the suffix items and the closing parenthesis to `tensor_str`.

    Each suffix is added after ``', '`` on the current line, unless the line
    would grow past ``PRINT_OPTS.linewidth`` (or `force_newline` is set), in
    which case it is placed on a new line indented by `indent` spaces.
    `force_newline` only affects the first suffix.
    """
    pieces = [tensor_str]
    # NOTE(review): this evaluates to the length of the last line of
    # `tensor_str` plus two (also when there is no newline, where rfind gives
    # -1). Kept exactly as is, since changing it would alter where existing
    # output wraps.
    current_len = len(tensor_str) - tensor_str.rfind('\n') + 1
    continuation = ',\n' + ' ' * indent

    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or current_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            pieces.append(continuation + suffix)
            current_len = indent + suffix_len
            force_newline = False
            continue
        pieces.append(', ' + suffix)
        current_len += suffix_len + 2

    pieces.append(')')
    return ''.join(pieces)
def get_summarized_data(self):
    """Return only the elements of `self` that a summarized print would show.

    Along every dimension longer than ``2 * PRINT_OPTS.edgeitems`` only the
    first and last ``edgeitems`` entries are kept; shorter dimensions are kept
    whole. Zero-dimensional tensors are returned as they are.

    Returns:
        A tensor holding the retained elements. If nothing is retained
        (``edgeitems`` is 0), an empty tensor of the same dtype is returned.
    """
    dim = self.dim()
    if dim == 0:
        return self

    edgeitems = PRINT_OPTS.edgeitems
    length = self.size(0)
    truncated = length > 2 * edgeitems

    if dim == 1:
        if truncated:
            # Slice the tail from an explicit start index: `self[-edgeitems:]`
            # would return the whole vector when edgeitems is 0.
            return torch.cat((self[:edgeitems], self[length - edgeitems:]))
        return self

    if truncated:
        kept = ([self[i] for i in range(0, edgeitems)] +
                [self[i] for i in range(length - edgeitems, length)])
    else:
        kept = [x for x in self]

    if not kept:
        # torch.stack() rejects an empty list; nothing is shown at this level.
        return self.new_empty((0,))
    return torch.stack([get_summarized_data(x) for x in kept])
def _str(self):
    """Build the complete ``tensor(...)`` text for `self`.

    The element listing comes from `_tensor_str`; descriptive suffixes
    (device, size, nnz, dtype, quantization parameters, layout, grad
    information, names) are collected in a fixed order and attached by
    `_add_suffixes`. Sparse tensors print their indices and values separately
    and always start their suffixes on a new line.
    """
    prefix = 'tensor('
    indent = len(prefix)
    suffixes = []

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    show_device = self.device.type != torch._C._get_default_device()
    if not show_device and self.device.type == 'cuda':
        # Same device type as the default: only a non-current CUDA index is shown.
        show_device = torch.cuda.current_device() != self.device.index
    if show_device:
        suffixes.append('device=\'' + str(self.device) + '\'')

    has_default_dtype = self.dtype in (torch.get_default_dtype(), torch.int64, torch.bool)

    if self.is_sparse:
        suffixes.append('size=' + str(tuple(self.shape)))
        suffixes.append('nnz=' + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append('dtype=' + str(self.dtype))

        # Indices and values are rendered the same way, each under its own prefix.
        components = (
            ('indices=tensor(', self._indices().detach()),
            ('values=tensor(', self._values().detach()),
        )
        rendered = []
        for component_prefix, component in components:
            component_str = _tensor_str(component, indent + len(component_prefix))
            if component.numel() == 0:
                component_str += ', size=' + str(tuple(component.shape))
            rendered.append(component_prefix + component_str + ')')
        tensor_str = (',\n' + ' ' * indent).join(rendered)
    elif self.is_quantized:
        suffixes.append('size=' + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append('dtype=' + str(self.dtype))
        # TODO: change to a call to self.q_scheme() when we add q_scheme method
        # and uncomment this
        # suffixes.append('quantization_scheme=' + 'per_tensor_affine')
        suffixes.append('scale=' + str(self.q_scale()))
        suffixes.append('zero_point=' + str(self.q_zero_point()))
        tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.numel() == 0 and not self.is_sparse:
        # Explicitly print the shape if it is not (0,), to match NumPy behavior
        if self.dim() != 1:
            suffixes.append('size=' + str(tuple(self.shape)))

        # In an empty tensor, there are no elements to infer if the dtype
        # should be int64, so it must be shown explicitly.
        if self.dtype != torch.get_default_dtype():
            suffixes.append('dtype=' + str(self.dtype))
        tensor_str = '[]'
    else:
        if not has_default_dtype:
            suffixes.append('dtype=' + str(self.dtype))

        # Non-strided layouts are converted with to_dense() before listing.
        printable = self.to_dense() if self.layout != torch.strided else self
        tensor_str = _tensor_str(printable, indent)

    if self.layout != torch.strided:
        suffixes.append('layout=' + str(self.layout))

    if self.grad_fn is not None:
        name = type(self.grad_fn).__name__
        if name == 'CppFunction':
            # Use the last '::'-separated component of the reported name.
            name = self.grad_fn.name().rsplit('::', 1)[-1]
        suffixes.append('grad_fn=<{}>'.format(name))
    elif self.requires_grad:
        suffixes.append('requires_grad=True')

    if torch._C._BUILD_NAMEDTENSOR and self.has_names():
        suffixes.append('names={}'.format(self.names))

    return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)