Skip to content

Commit

Permalink
Transform3D.map should return the type of obj passed in (pyqtgraph#3050)
Browse files Browse the repository at this point in the history
* Transform3D.map should return the type of obj passed in
  • Loading branch information
outofculture authored Jun 7, 2024
1 parent 5d89a8b commit d393146
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
11 changes: 0 additions & 11 deletions pyqtgraph/SRTTransform3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,17 +214,6 @@ def update(self):

def __repr__(self):
return str(self.saveState())

def matrix(self, nd=3):
if nd == 3:
return np.array(self.copyDataTo()).reshape(4,4)
elif nd == 2:
m = np.array(self.copyDataTo()).reshape(4,4)
m[2] = m[3]
m[:,2] = m[:,3]
return m[:3,:3]
else:
raise Exception("Argument 'nd' must be 2 or 3")

def __reduce__(self):
return SRTTransform3D, (self.saveState(),)
9 changes: 6 additions & 3 deletions pyqtgraph/Transform3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def matrix(self, nd=3):
return m[:3,:3]
else:
raise Exception("Argument 'nd' must be 2 or 3")

def map(self, obj):
"""
Extends QMatrix4x4.map() to allow mapping (3, ...) arrays of coordinates
Expand All @@ -45,8 +45,11 @@ def map(self, obj):
v = QtGui.QMatrix4x4.map(self, Vector(obj))
return type(obj)([v.x(), v.y(), v.z()])[:len(obj)]
else:
return QtGui.QMatrix4x4.map(self, obj)

retval = QtGui.QMatrix4x4.map(self, obj)
if not isinstance(retval, type(obj)):
return type(obj)(retval)
return retval

def inverted(self):
inv, b = QtGui.QMatrix4x4.inverted(self)
return Transform3D(inv), b
20 changes: 18 additions & 2 deletions tests/test_srttransform3d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest
from numpy.testing import assert_almost_equal, assert_array_almost_equal

import pyqtgraph as pg
from pyqtgraph.Qt import QtGui
from pyqtgraph.Qt import QtCore, QtGui

testPoints = np.array([
[0, 0, 0],
Expand All @@ -27,7 +28,7 @@ def testMatrix():

tr2 = pg.Transform3D(tr)
assert np.all(tr.matrix() == tr2.matrix())

# This is the most important test:
# The transition from Transform3D to SRTTransform3D is a tricky one.
tr3 = pg.SRTTransform3D(tr2)
Expand All @@ -36,3 +37,18 @@ def testMatrix():
assert_array_almost_equal(tr3.getRotation()[1], tr.getRotation()[1])
assert_array_almost_equal(tr3.getScale(), tr.getScale())
assert_array_almost_equal(tr3.getTranslation(), tr.getTranslation())


@pytest.mark.parametrize("v", [
pg.Vector((0, 0, 0)),
QtGui.QVector3D(0, 0, 0),
np.array((0, 0, 0)),
QtCore.QPoint(0, 0),
QtCore.QPointF(0.0, 0.0),
(0, 0, 0),
[0, 0],
])
def testMapTypes(v):
tr = pg.SRTTransform3D()
res = tr.map(v)
assert isinstance(res, type(v))

0 comments on commit d393146

Please sign in to comment.