From ba029572d17a5dc57ada8941ee4f9767bc5849b5 Mon Sep 17 00:00:00 2001 From: yasmmin Date: Fri, 18 Aug 2023 18:39:22 -0300 Subject: [PATCH] fix when graph grakel is initialized with attributes nd vertex/edge histogram kernels are used --- grakel/kernels/edge_histogram.py | 8 ++++++-- grakel/kernels/vertex_histogram.py | 8 ++++++-- grakel/kernels/weisfeiler_lehman.py | 23 ++++++++++++++++++++--- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/grakel/kernels/edge_histogram.py b/grakel/kernels/edge_histogram.py index 283eccc9..72e1eac1 100644 --- a/grakel/kernels/edge_histogram.py +++ b/grakel/kernels/edge_histogram.py @@ -96,7 +96,11 @@ def parse_input(self, X): continue else: # Our element is an iterable of at least 2 elements - L = x[2] + elements = list(itervalues( x[2] )) + if( isinstance( elements[0], Iterable) ): + L = list( map( lambda y: y[0], elements )) + else: + L = elements elif type(x) is Graph: # get labels in any existing format L = x.get_labels(purpose="any", label_type="edge") @@ -107,7 +111,7 @@ def parse_input(self, X): 'dict \n') # construct the data input for the numpy array - for (label, frequency) in iteritems(Counter(itervalues(L))): + for (label, frequency) in iteritems(Counter( L )): # for the row that corresponds to that graph rows.append(ni) diff --git a/grakel/kernels/vertex_histogram.py b/grakel/kernels/vertex_histogram.py index b2cea060..f9fc1394 100644 --- a/grakel/kernels/vertex_histogram.py +++ b/grakel/kernels/vertex_histogram.py @@ -96,7 +96,11 @@ def parse_input(self, X): continue else: # Our element is an iterable of at least 2 elements - L = x[1] + elements = list(itervalues( x[1] )) + if( isinstance( elements[0], Iterable) ): + L = list( map( lambda y: y[0], elements )) + else: + L = elements elif type(x) is Graph: # get labels in any existing format L = x.get_labels(purpose="any") @@ -107,7 +111,7 @@ def parse_input(self, X): 'dict \n') # construct the data input for the numpy array - for (label, frequency) in iteritems(Counter(itervalues(L))): + for (label, frequency) in iteritems(Counter( L )): # for the row that corresponds to that graph rows.append(ni) diff --git a/grakel/kernels/weisfeiler_lehman.py b/grakel/kernels/weisfeiler_lehman.py index 49a3b033..d5447dd9 100644 --- a/grakel/kernels/weisfeiler_lehman.py +++ b/grakel/kernels/weisfeiler_lehman.py @@ -178,7 +178,16 @@ def parse_input(self, X): Gs_ed[nx] = x.get_edge_dictionary() L[nx] = x.get_labels(purpose="dictionary") extras[nx] = extra - distinct_values |= set(itervalues(L[nx])) + elements = list(itervalues(L[nx])) + if( isinstance( elements[0], Iterable) ): + el = list( map( lambda y: y[0], elements )) + else: + el = elements + temp={} + for k,v in zip( list(L[nx].keys()), el): + temp[k] = v + L[nx] = temp + distinct_values |= set( el ) nx += 1 if nx == 0: raise ValueError('parsed input is empty') @@ -357,10 +366,18 @@ def transform(self, X): 'least one and at most 3 elements\n') Gs_ed[nx] = x.get_edge_dictionary() L[nx] = x.get_labels(purpose="dictionary") - + elements = list(itervalues( L[nx] )) + if( isinstance( elements[0], Iterable) ): + el = list( map( lambda y: y[0], elements )) + else: + el = elements + temp={} + for k,v in zip( list(L[nx].keys()), el): + temp[k] = v + L[nx] = temp # Hold all the distinct values distinct_values |= set( - v for v in itervalues(L[nx]) + v for v in el if v not in self._inv_labels[0]) nx += 1 if nx == 0: