From d393146f19d70cbeb68db2f6c71396edae486bed Mon Sep 17 00:00:00 2001 From: Martin Chase Date: Fri, 7 Jun 2024 11:16:57 -0700 Subject: [PATCH] Transform3D.map should return the type of obj passed in (#3050) * Transform3D.map should return the type of obj passed in --- pyqtgraph/SRTTransform3D.py | 11 ----------- pyqtgraph/Transform3D.py | 9 ++++++--- tests/test_srttransform3d.py | 20 ++++++++++++++++++-- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/pyqtgraph/SRTTransform3D.py b/pyqtgraph/SRTTransform3D.py index a1c820feba..432e1cea5f 100644 --- a/pyqtgraph/SRTTransform3D.py +++ b/pyqtgraph/SRTTransform3D.py @@ -214,17 +214,6 @@ def update(self): def __repr__(self): return str(self.saveState()) - - def matrix(self, nd=3): - if nd == 3: - return np.array(self.copyDataTo()).reshape(4,4) - elif nd == 2: - m = np.array(self.copyDataTo()).reshape(4,4) - m[2] = m[3] - m[:,2] = m[:,3] - return m[:3,:3] - else: - raise Exception("Argument 'nd' must be 2 or 3") def __reduce__(self): return SRTTransform3D, (self.saveState(),) diff --git a/pyqtgraph/Transform3D.py b/pyqtgraph/Transform3D.py index 379dc3fb0b..354602bd7b 100644 --- a/pyqtgraph/Transform3D.py +++ b/pyqtgraph/Transform3D.py @@ -30,7 +30,7 @@ def matrix(self, nd=3): return m[:3,:3] else: raise Exception("Argument 'nd' must be 2 or 3") - + def map(self, obj): """ Extends QMatrix4x4.map() to allow mapping (3, ...) arrays of coordinates @@ -45,8 +45,11 @@ def map(self, obj): v = QtGui.QMatrix4x4.map(self, Vector(obj)) return type(obj)([v.x(), v.y(), v.z()])[:len(obj)] else: - return QtGui.QMatrix4x4.map(self, obj) - + retval = QtGui.QMatrix4x4.map(self, obj) + if not isinstance(retval, type(obj)): + return type(obj)(retval) + return retval + def inverted(self): inv, b = QtGui.QMatrix4x4.inverted(self) return Transform3D(inv), b diff --git a/tests/test_srttransform3d.py b/tests/test_srttransform3d.py index 6a90a6cc3d..28a04abefc 100644 --- a/tests/test_srttransform3d.py +++ b/tests/test_srttransform3d.py @@ -1,8 +1,9 @@ import numpy as np +import pytest from numpy.testing import assert_almost_equal, assert_array_almost_equal import pyqtgraph as pg -from pyqtgraph.Qt import QtGui +from pyqtgraph.Qt import QtCore, QtGui testPoints = np.array([ [0, 0, 0], @@ -27,7 +28,7 @@ def testMatrix(): tr2 = pg.Transform3D(tr) assert np.all(tr.matrix() == tr2.matrix()) - + # This is the most important test: # The transition from Transform3D to SRTTransform3D is a tricky one. tr3 = pg.SRTTransform3D(tr2) @@ -36,3 +37,18 @@ def testMatrix(): assert_array_almost_equal(tr3.getRotation()[1], tr.getRotation()[1]) assert_array_almost_equal(tr3.getScale(), tr.getScale()) assert_array_almost_equal(tr3.getTranslation(), tr.getTranslation()) + + +@pytest.mark.parametrize("v", [ + pg.Vector((0, 0, 0)), + QtGui.QVector3D(0, 0, 0), + np.array((0, 0, 0)), + QtCore.QPoint(0, 0), + QtCore.QPointF(0.0, 0.0), + (0, 0, 0), + [0, 0], +]) +def testMapTypes(v): + tr = pg.SRTTransform3D() + res = tr.map(v) + assert isinstance(res, type(v))