From 75b48995ba7d464cfe830337d79c78a0936c9f28 Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Mon, 20 Nov 2017 11:20:52 +1100 Subject: [PATCH] Clean up column subset code; usage of set methods, method type signature, removed unusued code --- pandas_schema/schema.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/pandas_schema/schema.py b/pandas_schema/schema.py index 486eae3..5c0442e 100644 --- a/pandas_schema/schema.py +++ b/pandas_schema/schema.py @@ -10,6 +10,7 @@ class Schema: """ A schema that defines the columns required in the target DataFrame """ + def __init__(self, columns: typing.Iterable[Column], ordered: bool = False): """ :param columns: A list of column objects @@ -28,35 +29,39 @@ def __init__(self, columns: typing.Iterable[Column], ordered: bool = False): self.columns = list(columns) self.ordered = ordered - def validate(self, df: pd.DataFrame, columns: typing.List[Column]=None) -> typing.List[ValidationWarning]: + def validate(self, df: pd.DataFrame, columns: typing.List[str] = None) -> typing.List[ValidationWarning]: """ Runs a full validation of the target DataFrame using the internal columns list :param df: A pandas DataFrame to validate + :param columns: A list of columns indicating a subset of the schema that we want to validate :return: A list of ValidationWarning objects that list the ways in which the DataFrame was invalid """ errors = [] df_cols = len(df.columns) - # If no columns are passed, validate against every column in the schema - # This is the default behaviour + # If no columns are passed, validate against every column in the schema. This is the default behaviour if columns is None: schema_cols = len(self.columns) columns_to_pair = self.columns if df_cols != schema_cols: - errors.append(ValidationWarning('Invalid number of columns. The schema specifies {}, but the data frame has {}'.format(schema_cols, - df_cols))) + errors.append( + ValidationWarning( + 'Invalid number of columns. The schema specifies {}, but the data frame has {}'.format( + schema_cols, + df_cols) + ) + ) return errors - # Else check that columns passed in as an argument are part of the - # current schema, else raise an error + # If we did pass in columns, check that they are part of the current schema else: - if set(self.get_column_names()).intersection(columns) == set(columns): - schema_cols = len(columns) + if set(columns).issubset(self.get_column_names()): columns_to_pair = [column for column in self.columns if column.name in columns] else: - raise PanSchArgumentError('Columns {} passed in are not part of the schema'.format( - set(columns).difference(self.columns))) + raise PanSchArgumentError( + 'Columns {} passed in are not part of the schema'.format(set(columns).difference(self.columns)) + ) # We associate the column objects in the schema with data frame series either by name or by position, depending # on the value of self.ordered @@ -69,11 +74,11 @@ def validate(self, df: pd.DataFrame, columns: typing.List[Column]=None) -> typin # Throw an error if the schema column isn't in the data frame if column.name not in df: - errors.append(ValidationWarning('The column {} exists in the schema but not in the data frame'.format(column.name))) + errors.append(ValidationWarning( + 'The column {} exists in the schema but not in the data frame'.format(column.name))) return errors column_pairs.append((df[column.name], column)) - # Iterate over each pair of schema columns and data frame series and run validations for series, column in column_pairs: