diff --git a/src/ggrc/converters/base.py b/src/ggrc/converters/base.py
index 149fdb825d23..9b7982c96856 100644
--- a/src/ggrc/converters/base.py
+++ b/src/ggrc/converters/base.py
@@ -48,6 +48,7 @@ class ImportConverter(BaseConverter):
       "delete",
       "task_type",
       "audit",
+      "assessment_template",
   ]
 
   def __init__(self, dry_run=True, csv_data=None):
@@ -81,15 +82,10 @@ def import_csv_data(self):
     revision_ids = []
 
     for converter in self.initialize_block_converters():
-      converter.row_converters_from_csv()
-      for attr_name in self.priority_columns:
-        converter.handle_row_data(attr_name)
-      converter.handle_row_data()
-      converter.import_objects()
-      converter.import_secondary_objects()
-
+      if not converter.ignore:
+        converter.import_csv_data()
+        revision_ids.extend(converter.revision_ids)
       self.response_data.append(converter.get_info())
-      revision_ids.extend(converter.revision_ids)
 
     self._start_compute_attributes_job(revision_ids)
     self.drop_cache()
diff --git a/src/ggrc/converters/base_block.py b/src/ggrc/converters/base_block.py
index b6bdc5f6f5a5..faae0f628958 100644
--- a/src/ggrc/converters/base_block.py
+++ b/src/ggrc/converters/base_block.py
@@ -13,10 +13,8 @@
 from collections import Counter
 
 from cached_property import cached_property
-from sqlalchemy import exc
 from sqlalchemy import or_
 from sqlalchemy import and_
-from sqlalchemy.orm.exc import UnmappedInstanceError
 from flask import _app_ctx_stack
 
 from ggrc import db
@@ -27,26 +25,17 @@
 from ggrc.utils import list_chunks
 from ggrc.converters import errors
 from ggrc.converters import get_shared_unique_rules
-from ggrc.converters import pre_commit_checks
 from ggrc.converters import base_row
 from ggrc.converters.import_helper import get_column_order
 from ggrc.converters.import_helper import get_object_column_definitions
-from ggrc.services.common import get_modified_objects
-from ggrc.services.common import update_snapshot_index
-from ggrc.cache import utils as cache_utils
-from ggrc.utils.log_event import log_event
 from ggrc.services import signals
 
 from ggrc_workflows.models.cycle_task_group_object_task import \
     CycleTaskGroupObjectTask
-from ggrc.models.exceptions import StatusValidationError
-
 
 # pylint: disable=invalid-name
 logger = getLogger(__name__)
 
-CACHE_EXPIRY_IMPORT = 600
-
 
 class BlockConverter(object):
   # pylint: disable=too-many-public-methods
@@ -76,6 +65,9 @@ class BlockConverter(object):
       "valid_values": "list of valid values"
   """
 
+
+  CACHE_EXPIRY_IMPORT = 600
+
   def __init__(self, converter, object_class, class_name,
                operation, object_ids=None, raw_headers=None,
                offset=None, rows=None):
@@ -108,7 +100,9 @@ def __init__(self, converter, object_class, class_name,
     self.class_name = class_name
     # TODO: remove 'if' statement. Init should initialize only.
     if self.object_class:
-      self.object_headers = get_object_column_definitions(self.object_class)
+      names = {n.strip().strip("*").lower() for n in raw_headers or []} or None
+      self.object_headers = get_object_column_definitions(self.object_class,
                                                           names)
       if not raw_headers:
         all_header_names = [unicode(key)
                             for key in self._get_header_names().keys()]
@@ -430,8 +424,9 @@ def __init__(self, converter, object_class, rows, raw_headers,
         operation="import"
     )
     self.converter = converter
-    self.unique_counts = self.get_unique_counts_dict(self.object_class)
+    self.unique_values = self.get_unique_values_dict(self.object_class)
     self.revision_ids = []
+    self._import_info = self._make_empty_info()
 
   def check_block_restrictions(self):
     """Check some block related restrictions"""
@@ -469,51 +464,31 @@ def row_converters_from_csv(self):
     """ Generate a row converter object for every csv row """
     if self.ignore:
       return
-    self.row_converters = []
     for i, row in enumerate(self.rows):
-      row = base_row.ImportRowConverter(self, self.object_class, row=row,
+      yield base_row.ImportRowConverter(self, self.object_class, row=row,
                                         headers=self.headers, index=i)
-      self.row_converters.append(row)
-
-  def handle_row_data(self, field_list=None):
-    """Handle row data for all row converters on import.
 
-    Note: When field_list is set, we are handling priority columns and we
-    don't have all the data needed for checking mandatory and duplicate values.
+  @property
+  def handle_fields(self):
+    """Column definitions in correct processing order.
 
-    Args:
-      filed_list (list of strings): list of fields that should be handled by
-        row converters. This is used only for handling priority columns.
+    Special columns which are primary keys logically or affect the valid set of
+    columns go before usual columns.
     """
-    if self.ignore:
-      return
-    for row_converter in self.row_converters:
-      row_converter.handle_row_data(field_list)
-    if field_list is None:
-      self.check_mandatory_fields()
-      self.check_unique_columns()
+    return [
+        k for k in self.headers if k in self.converter.priority_columns
+    ] + [
+        k for k in self.headers if k not in self.converter.priority_columns
+    ]
 
-  def check_mandatory_fields(self):
-    for row_converter in self.row_converters:
-      row_converter.check_mandatory_fields()
-
-  def check_unique_columns(self):
-    self.generate_unique_counts()
-    for key, counts in self.unique_counts.items():
-      self.remove_duplicate_keys(key, counts)
+  def import_csv_data(self):
+    """Perform import sequence for the block."""
+    for row in self.row_converters_from_csv():
+      row.process_row()
+      self._update_info(row)
 
-  def generate_unique_counts(self):
-    """Populate unique_counts for sent data."""
-    for key, header in self.headers.items():
-      if not header["unique"]:
-        continue
-      for rel_index, row in enumerate(self.row_converters):
-        value = row.get_value(key)
-        if value:
-          self.unique_counts[key][value].add(self._calc_abs_index(rel_index))
-
-  def get_unique_counts_dict(self, object_class):
-    """Get the varible to storing unique counts.
+  def get_unique_values_dict(self, object_class):
+    """Get the varible to storing row numbers for unique values.
 
     Make sure to always return the same variable for object with shared
     tables, as defined in sharing rules.
@@ -522,130 +497,14 @@ def get_unique_counts_dict(self, object_class):
     classes = sharing_rules.get(object_class, object_class)
     shared_state = self.converter.shared_state
     if classes not in shared_state:
-      shared_state[classes] = defaultdict(
-          lambda: structures.CaseInsensitiveDefaultDict(set)
-      )
+      shared_state[classes] = defaultdict(structures.CaseInsensitiveDict)
     return shared_state[classes]
 
-  def import_objects(self):
-    """Add all objects to the database.
-
-    This function flushes all objects to the database if the dry_run flag is
-    not set and all signals for the imported objects get sent.
-    """
-    if self.ignore:
-      return
-
-    self._import_objects_prepare()
-
-    if not self.converter.dry_run:
-      new_objects = []
-      for row_converter in self.row_converters:
-        row_converter.send_pre_commit_signals()
-      for row_converter in self.row_converters:
-        try:
-          row_converter.insert_object()
-          db.session.flush()
-        except exc.SQLAlchemyError as err:
-          db.session.rollback()
-          logger.exception("Import failed with: %s", err.message)
-          row_converter.add_error(errors.UNKNOWN_ERROR)
-        else:
-          if row_converter.is_new and not row_converter.ignore:
-            new_objects.append(row_converter.obj)
-      self.send_collection_post_signals(new_objects)
-      import_event = self.save_import()
-      for row_converter in self.row_converters:
-        row_converter.send_post_commit_signals(event=import_event)
-
-  def _import_objects_prepare(self):
-    """Setup all objects and do pre-commit checks for them."""
-    for row_converter in self.row_converters:
-      row_converter.setup_object()
-
-    for row_converter in self.row_converters:
-      self._check_object(row_converter)
-
-    self.clean_session_from_ignored_objs()
-
-  def clean_session_from_ignored_objs(self):
-    """Clean DB session from ignored objects.
-
-    This function expunges objects from 'db.session' which are in rows that
-    marked as 'ignored' before commit.
-    """
-    for row_converter in self.row_converters:
-      obj = row_converter.obj
-      try:
-        if row_converter.do_not_expunge:
-          continue
-        if row_converter.ignore and obj in db.session:
-          db.session.expunge(obj)
-      except UnmappedInstanceError:
-        continue
-
-  def save_import(self):
-    """Commit all changes in the session and update memcache."""
-    try:
-      modified_objects = get_modified_objects(db.session)
-      import_event = log_event(db.session, None)
-      cache_utils.update_memcache_before_commit(
-          self, modified_objects, CACHE_EXPIRY_IMPORT)
-      for row_converter in self.row_converters:
-        try:
-          row_converter.send_before_commit_signals(import_event)
-        except StatusValidationError as exp:
-          status_alias = row_converter.headers.get("status",
-                                                   {}).get("display_name")
-          row_converter.add_error(
-              errors.VALIDATION_ERROR,
-              column_name=status_alias,
-              message=exp.message
-          )
-      db.session.commit()
-      self._store_revision_ids(import_event)
-      cache_utils.update_memcache_after_commit(self)
-      update_snapshot_index(db.session, modified_objects)
-      return import_event
-    except exc.SQLAlchemyError as err:
-      db.session.rollback()
-      logger.exception("Import failed with: %s", err.message)
-      self.add_errors(errors.UNKNOWN_ERROR, line=self.offset + 2)
-
   def _store_revision_ids(self, event):
     """Store revision ids from the current event."""
     if event:
       self.revision_ids.extend(revision.id for revision in event.revisions)
 
-  def import_secondary_objects(self):
-    """Import secondary objects procedure."""
-    for row_converter in self.row_converters:
-      row_converter.setup_secondary_objects()
-
-    if not self.converter.dry_run:
-      for row_converter in self.row_converters:
-        try:
-          row_converter.insert_secondary_objects()
-        except exc.SQLAlchemyError as err:
-          db.session.rollback()
-          logger.exception("Import failed with: %s", err.message)
-          row_converter.add_error(errors.UNKNOWN_ERROR)
-      self.save_import()
-
-  @staticmethod
-  def _check_object(row_converter):
-    """Check object if it has any pre commit checks.
-
-    The check functions can mutate the row_converter object and mark it
-    to be ignored if there are any errors detected.
-
-    Args:
-      row_converter: Row converter for the row we want to check.
-    """
-    checker = pre_commit_checks.CHECKS.get(type(row_converter.obj).__name__)
-    if checker and callable(checker):
-      checker(row_converter)
-
   @staticmethod
   def send_collection_post_signals(new_objects):
     """Send bulk create pre-commit signals."""
@@ -663,86 +522,45 @@ def send_collection_post_signals(new_objects):
 
   def get_info(self):
     """Returns info dict for current block."""
-    created = 0
-    updated = 0
-    ignored = 0
-    deleted = 0
-    deprecated = 0
-    for row in self.row_converters:
-      if row.ignore:
-        ignored += 1
-        continue
-      if row.is_delete:
-        deleted += 1
-        continue
-      if row.is_new:
-        created += 1
-      else:
-        updated += 1
-      deprecated += int(row.is_deprecated)
-    info = {
-        "name": self.name,
-        "rows": len(self.rows),
-        "created": created,
-        "updated": updated,
-        "ignored": ignored,
-        "deleted": deleted,
-        "deprecated": deprecated,
+    info = self._import_info.copy()
+    info.update({
         "block_warnings": self.block_warnings,
         "block_errors": self.block_errors,
         "row_warnings": self.row_warnings,
         "row_errors": self.row_errors,
-    }
-
+    })
     return info
 
-  def _in_range(self, index, remove_offset=True):
-    """Checks if the value provided lays within the range of lines of the
-    current block
-    """
-    if remove_offset:
-      index = self._calc_offset(index)
-    return index >= 0 and index < len(self.row_converters)
+  def _make_empty_info(self):
+    """Empty info dict with all counts zero."""
+    return {
+        "name": self.name,
+        "rows": 0,
+        "created": 0,
+        "updated": 0,
+        "ignored": 0,
+        "deleted": 0,
+        "deprecated": 0,
+        "block_warnings": [],
+        "block_errors": [],
+        "row_warnings": [],
+        "row_errors": [],
+    }
 
-  def _calc_offset(self, index):
-    """Calculate an offset relative to the current block beginning
-    given an absolute line index
-    """
-    return index - self.BLOCK_OFFSET - self.offset
+  def _update_info(self, row):
+    """Update counts for info response from row metadata."""
+    self._import_info["rows"] += 1
+    if row.ignore:
+      self._import_info["ignored"] += 1
+    elif row.is_delete:
+      self._import_info["deleted"] += 1
+    elif row.is_new:
+      self._import_info["created"] += 1
+    else:
+      self._import_info["updated"] += 1
 
-  def _calc_abs_index(self, rel_index):
-    """Calculate an absolute line number given a relative index
-    """
-    return rel_index + self.BLOCK_OFFSET + self.offset
-
-  def remove_duplicate_keys(self, key, counts):
-
-    for value, indexes in counts.items():
-      if not any(self._in_range(index) for index in indexes):
-        continue  # ignore duplicates in other related code blocks
-
-      indexes = sorted(list(indexes))
-      if len(indexes) > 1:
-        str_indexes = [str(index) for index in indexes]
-        self.row_errors.append(
-            errors.DUPLICATE_VALUE_IN_CSV.format(
-                line_list=", ".join(str_indexes),
-                column_name=self.headers[key]["display_name"],
-                s="s" if len(str_indexes) > 2 else "",
-                value=value,
-                ignore_lines=", ".join(str_indexes[1:]),
-            )
-        )
-      if key == "slug":  # mark obj not to be expunged from the session
-        for index in indexes:
-          offset_index = self._calc_offset(index)
-          if self._in_range(offset_index, remove_offset=False):
-            self.row_converters[offset_index].set_do_not_expunge()
-
-      for index in indexes[1:]:
-        offset_index = self._calc_offset(index)
-        if self._in_range(offset_index, remove_offset=False):
-          self.row_converters[offset_index].set_ignore()
+    if row.is_deprecated:
+      self._import_info["deprecated"] += 1
 
 
 class ExportBlockConverter(BlockConverter):
diff --git a/src/ggrc/converters/base_row.py b/src/ggrc/converters/base_row.py
index c0fc4fc040c0..d5a505e01652 100644
--- a/src/ggrc/converters/base_row.py
+++ b/src/ggrc/converters/base_row.py
@@ -5,17 +5,32 @@
 """
 
 import collections
+from logging import getLogger
+
+from sqlalchemy import exc
+from sqlalchemy.orm.exc import UnmappedInstanceError
 
 import ggrc.services
 from ggrc import db
 from ggrc.converters import errors
 from ggrc.converters import get_importables
+from ggrc.converters import pre_commit_checks
 from ggrc.login import get_current_user_id
-from ggrc.models.reflection import AttributeInfo
+from ggrc.models import all_models
+from ggrc.models.exceptions import StatusValidationError
 from ggrc.rbac import permissions
 from ggrc.services import signals
+from ggrc.snapshotter import create_snapshots
 from ggrc.utils import dump_attrs
+from ggrc.models.reflection import AttributeInfo
+from ggrc.services.common import get_modified_objects
+from ggrc.services.common import update_snapshot_index
+from ggrc.cache import utils as cache_utils
+from ggrc.utils.log_event import log_event
+
+logger = getLogger(__name__)
+
 
 class RowConverter(object):
   """Base class for handling row data."""
@@ -40,13 +55,74 @@ def __init__(self, block_converter, object_class, headers, index, **options):
     self.is_new = True
     self.is_delete = False
     self.is_deprecated = False
-    self.do_not_expunge = False
     self.ignore = False
     self.row = options.get("row", [])
     self.id_key = ""
     self.line = self.index + self.block_converter.offset + \
         self.block_converter.BLOCK_OFFSET
     self.initial_state = None
+    self.is_new_object_set = False
+
+  def handle_raw_cell(self, attr_name, idx, header_dict):
+    """Process raw value from self.row[idx] for attr_name.
+
+    This function finds and instantiates the correct handler class and handles
+    special logic for deprecated status and primary key attributes, as well as
+    value uniqueness.
+    """
+    handler = header_dict["handler"]
+    item = handler(self, attr_name, parse=True,
+                   raw_value=self.row[idx], **header_dict)
+    if header_dict.get("type") == AttributeInfo.Type.PROPERTY:
+      self.attrs[attr_name] = item
+    else:
+      self.objects[attr_name] = item
+    if attr_name == "status" and hasattr(self.obj, "DEPRECATED"):
+      self.is_deprecated = (
+          self.obj.DEPRECATED == item.value != self.obj.status
+      )
+    if attr_name in ("slug", "email"):
+      self.id_key = attr_name
+      self.obj = self.get_or_generate_object(attr_name)
+      item.set_obj_attr()
+    if header_dict["unique"]:
+      value = self.get_value(attr_name)
+      if value:
+        unique_values = self.block_converter.unique_values
+        if unique_values[attr_name].get(value) is not None:
+          self.add_error(
+              errors.DUPLICATE_VALUE_IN_CSV.format(
+                  line=self.line,
+                  processed_line=unique_values[attr_name][value],
+                  column_name=header_dict["display_name"],
+                  value=value,
+              ),
+          )
+          item.is_duplicate = True
+        else:
+          self.block_converter.unique_values[attr_name][value] = self.line
+    item.check_unique_consistency()
+
+  def handle_raw_data(self):
+    """Pass raw values into column handlers for all cell in the row."""
+    row_headers = {attr_name: (idx, header_dict)
+                   for idx, (attr_name, header_dict)
+                   in enumerate(self.headers.iteritems())}
+    for attr_name in self.block_converter.handle_fields:
+      if attr_name not in row_headers or self.is_delete:
+        continue
+      idx, header_dict = row_headers[attr_name]
+      self.handle_raw_cell(attr_name, idx, header_dict)
+
+  def update_new_obj_cache(self):
+    if not self.is_new or not getattr(self.obj, self.id_key):
+      return
+    self.is_new_object_set = True
+    self.block_converter.converter.new_objects[
+        self.obj.__class__
+    ][
+        getattr(self.obj, self.id_key)
+    ] = self.obj
 
   def add_error(self, template, **kwargs):
     """Add error for current row.
@@ -60,46 +136,20 @@
     """
     message = template.format(line=self.line, **kwargs)
     self.block_converter.row_errors.append(message)
-    new_objects = self.block_converter.converter.new_objects[self.object_class]
-    key = self.get_value(self.id_key)
-    if key in new_objects:
-      del new_objects[key]
+    if self.is_new_object_set:
+      new_objects = self.block_converter.converter.new_objects[
+          self.object_class
+      ]
+      key = self.get_value(self.id_key)
+      if key in new_objects:
+        del new_objects[key]
+      self.is_new_object_set = False
     self.ignore = True
 
   def add_warning(self, template, **kwargs):
     message = template.format(line=self.line, **kwargs)
     self.block_converter.row_warnings.append(message)
 
-  def handle_csv_row_data(self, field_list=None):
-    """ Pack row data with handlers """
-    handle_fields = self.headers if field_list is None else field_list
-    for i, (attr_name, header_dict) in enumerate(self.headers.items()):
-      if attr_name not in handle_fields or \
-          attr_name in self.attrs or \
-          self.is_delete:
-        continue
-      handler = header_dict["handler"]
-      item = handler(self, attr_name, parse=True,
-                     raw_value=self.row[i], **header_dict)
-      if header_dict.get("type") == AttributeInfo.Type.PROPERTY:
-        self.attrs[attr_name] = item
-      else:
-        self.objects[attr_name] = item
-      if not self.ignore:
-        if attr_name == "status" and hasattr(self.obj, "DEPRECATED"):
-          self.is_deprecated = (
-              self.obj.DEPRECATED == item.value != self.obj.status
-          )
-        if attr_name in ("slug", "email"):
-          self.id_key = attr_name
-          self.obj = self.get_or_generate_object(attr_name)
-        item.set_obj_attr()
-        item.check_unique_consistency()
-
-  def handle_row_data(self, field_list=None):
-    """Handle row data on import"""
-    self.handle_csv_row_data(field_list)
-
   def check_mandatory_fields(self):
     """Check if the new object contains all mandatory columns."""
     if not self.is_new or self.is_delete or self.ignore:
@@ -128,14 +178,6 @@ def get_value(self, key):
 
   def set_ignore(self, ignore=True):
     self.ignore = ignore
 
-  def set_do_not_expunge(self, do_not_expunge=True):
-    """Mark an ignored object not to be expunged from a session
-
-    We may not expunge objects with duplicate slugs, because they represent
-    the same object.
-    """
-    self.do_not_expunge = do_not_expunge
-
   def get_or_generate_object(self, attr_name):
     """Fetch an existing object if possible or create and return a new one.
@@ -185,6 +227,107 @@ def setup_secondary_objects(self):
       return
     for mapping in self.objects.values():
       mapping.set_obj_attr()
+    if self.block_converter.converter.dry_run:
+      return
+    try:
+      self.insert_secondary_objects()
+    except exc.SQLAlchemyError as err:
+      db.session.rollback()
+      logger.exception("Import failed with: %s", err.message)
+      self.add_error(errors.UNKNOWN_ERROR)
+
+  def process_row(self):
+    """Parse, set, validate and commit data specified in self.row."""
+    self.handle_raw_data()
+    self.check_mandatory_fields()
+    if self.ignore:
+      return
+    self.update_new_obj_cache()
+    self.setup_object()
+    self.check_object()
+    try:
+      if self.ignore and self.obj in db.session:
+        db.session.expunge(self.obj)
+    except UnmappedInstanceError:
+      return
+    if self.block_converter.ignore:
+      return
+    self.flush_object()
+    self.setup_secondary_objects()
+    self.commit_object()
+
+  def check_object(self):
+    """Check object if it has any pre commit checks.
+
+    The check functions can mutate the row_converter object and mark it
+    to be ignored if there are any errors detected.
+
+    Args:
+      row_converter: Row converter for the row we want to check.
+    """
+    checker = pre_commit_checks.CHECKS.get(type(self.obj).__name__)
+    if checker and callable(checker):
+      checker(self)
+
+  def flush_object(self):
+    """Flush dirty data related to the current row."""
+    if self.block_converter.converter.dry_run or self.ignore:
+      return
+    self.send_pre_commit_signals()
+    try:
+      if self.object_class == all_models.Audit and self.is_new:
+        # This hack is needed only for snapshot creation
+        # for audit during import, this is really bad and
+        # need to be refactored
+        import_event = log_event(db.session, None)
+      self.insert_object()
+      db.session.flush()
+      if self.object_class == all_models.Audit and self.is_new:
+        # This hack is needed only for snapshot creation
+        # for audit during import, this is really bad and
+        # need to be refactored
+        create_snapshots(self.obj, import_event)
+    except exc.SQLAlchemyError as err:
+      db.session.rollback()
+      logger.exception("Import failed with: %s", err.message)
+      self.add_error(errors.UNKNOWN_ERROR)
+      return
+    if self.is_new and not self.ignore:
+      self.block_converter.send_collection_post_signals([self.obj])
+
+  def commit_object(self):
+    """Commit the row.
+
+    This method also calls pre-and post-commit signals and handles failures.
+    """
+    if self.block_converter.converter.dry_run or self.ignore:
+      return
+    try:
+      modified_objects = get_modified_objects(db.session)
+      import_event = log_event(db.session, None)
+      cache_utils.update_memcache_before_commit(
+          self.block_converter,
+          modified_objects,
+          self.block_converter.CACHE_EXPIRY_IMPORT,
+      )
+      try:
+        self.send_before_commit_signals(import_event)
+      except StatusValidationError as exp:
+        status_alias = self.headers.get("status", {}).get("display_name")
+        self.add_error(errors.VALIDATION_ERROR,
+                       column_name=status_alias,
+                       message=exp.message)
+      db.session.commit()
+      self.block_converter._store_revision_ids(import_event)
+      cache_utils.update_memcache_after_commit(self.block_converter)
+      update_snapshot_index(db.session, modified_objects)
+    except exc.SQLAlchemyError as err:
+      db.session.rollback()
+      logger.exception("Import failed with: %s", err.message)
+      self.block_converter.add_errors(errors.UNKNOWN_ERROR,
                                       line=self.offset + 2)
+    else:
+      self.send_post_commit_signals(event=import_event)
 
   def setup_object(self):
     """ Set the object values or relate object values
diff --git a/src/ggrc/converters/errors.py b/src/ggrc/converters/errors.py
index 397db85761a4..d00056510df2 100644
--- a/src/ggrc/converters/errors.py
+++ b/src/ggrc/converters/errors.py
@@ -34,9 +34,8 @@
     u"Duplicates: {duplicates}"
 )
 
-DUPLICATE_VALUE_IN_CSV = (u"Lines {line_list} have same {column_name}"
-                          u" '{value}'. Line{s} {ignore_lines} will be"
-                          u" ignored.")
+DUPLICATE_VALUE_IN_CSV = (u"Line {line} has the same {column_name} '{value}' "
+                          u"as {processed_line}. The line will be ignored.")
 
 MAP_UNMAP_CONFLICT = (u"Line {line}: Object '{slug}' scheduled for mapping and"
                       u" unmapping at the same time. Mapping rule update will"
diff --git a/src/ggrc/converters/handlers/handlers.py b/src/ggrc/converters/handlers/handlers.py
index 0191c9d855d7..61943f7a557a 100644
--- a/src/ggrc/converters/handlers/handlers.py
+++ b/src/ggrc/converters/handlers/handlers.py
@@ -51,6 +51,7 @@ def __init__(self, row_converter, key, **options):
     self.key = key
     self.value = None
     self.set_empty = False
+    self.is_duplicate = False
     self.raw_value = options.get("raw_value", "").strip()
     self.validator = options.get("validator")
     self.mandatory = options.get("mandatory", False)
@@ -72,6 +73,9 @@ def check_unique_consistency(self):
       return
     if not self.row_converter.obj:
       return
+    if self.is_duplicate:
+      # a hack to avoid two different errors for the same non-unique cell
+      return
     nr_duplicates = self.row_converter.object_class.query.filter(and_(
         getattr(self.row_converter.object_class, self.key) == self.value,
         self.row_converter.object_class.id != self.row_converter.obj.id
diff --git a/src/ggrc/converters/import_helper.py b/src/ggrc/converters/import_helper.py
index 78c8364fc656..c0f21a5f6aed 100644
--- a/src/ggrc/converters/import_helper.py
+++ b/src/ggrc/converters/import_helper.py
@@ -21,7 +21,7 @@
 logger = logging.getLogger(__name__)
 
 
-def get_object_column_definitions(object_class):
+def get_object_column_definitions(object_class, fields=None):
   """Attach additional info to attribute definitions.
 
   Fetches the attribute info (_aliases) for the given object class and adds
@@ -35,7 +35,8 @@
   Returns:
     dict: Updated attribute definitions dict with additional data.
   """
-  attributes = AttributeInfo.get_object_attr_definitions(object_class)
+  attributes = AttributeInfo.get_object_attr_definitions(object_class,
                                                          fields=fields)
   column_handlers = model_column_handlers(object_class)
   for key, attr in attributes.iteritems():
     handler_key = attr.get("handler_key", key)
diff --git a/src/ggrc/converters/snapshot_block.py b/src/ggrc/converters/snapshot_block.py
index 426e987d87b0..4c7863a5969b 100644
--- a/src/ggrc/converters/snapshot_block.py
+++ b/src/ggrc/converters/snapshot_block.py
@@ -63,10 +63,6 @@ def __init__(self, converter, ids, fields=None):
     self.ids = ids
     self.fields = fields or []
 
-  @staticmethod
-  def handle_row_data():
-    pass
-
   @property
   def name(self):
     return "{} Snapshot".format(self.child_type)
diff --git a/src/ggrc/models/mixins/customattributable.py b/src/ggrc/models/mixins/customattributable.py
index 4294f443fee6..23684d83c558 100644
--- a/src/ggrc/models/mixins/customattributable.py
+++ b/src/ggrc/models/mixins/customattributable.py
@@ -6,6 +6,7 @@
 import collections
 from logging import getLogger
 
+import sqlalchemy
 from sqlalchemy import and_
 from sqlalchemy import orm
 from sqlalchemy import or_
@@ -406,7 +407,7 @@ def custom_attributes(self, src):
     }, service=self.__class__.__name__)
 
   @classmethod
-  def get_custom_attribute_definitions(cls):
+  def get_custom_attribute_definitions(cls, field_names=None):
     """Get all applicable CA definitions (even ones without a value yet)."""
     from ggrc.models.custom_attribute_definition import \
         CustomAttributeDefinition as cad
@@ -420,6 +421,10 @@ def get_custom_attribute_definitions(cls):
     query = cad.query.filter(
         cad.definition_type == utils.underscore_from_camelcase(cls.__name__)
     )
+    if field_names:
+      query = query.filter(
+          sqlalchemy.or_(cad.title.in_(field_names), cad.mandatory)
+      )
     return query.options(
         orm.undefer_group('CustomAttributeDefinition_complete')
     )
diff --git a/src/ggrc/models/reflection.py b/src/ggrc/models/reflection.py
index ddc03c43c825..6c5fb2e98871 100644
--- a/src/ggrc/models/reflection.py
+++ b/src/ggrc/models/reflection.py
@@ -322,7 +322,8 @@ def get_mapping_definitions(cls, object_class):
     return definitions
 
   @classmethod
-  def get_custom_attr_definitions(cls, object_class, ca_cache=None):
+  def get_custom_attr_definitions(cls, object_class,
                                   ca_cache=None, fields=None):
     """Get column definitions for custom attributes on object_class.
 
     Args:
@@ -341,7 +342,7 @@ def get_custom_attr_definitions(cls, object_class, ca_cache=None):
     if isinstance(ca_cache, dict) and object_name:
       custom_attributes = ca_cache.get(object_name, [])
     else:
-      custom_attributes = object_class.get_custom_attribute_definitions()
+      custom_attributes = object_class.get_custom_attribute_definitions(fields)
     for attr in custom_attributes:
       description = attr.helptext or u""
       if (attr.attribute_type == attr.ValidTypes.DROPDOWN and
@@ -361,16 +362,16 @@ def get_custom_attr_definitions(cls, object_class, ca_cache=None):
         definition_ids = definitions.get(attr_name, {}).get("definition_ids",
                                                             [])
         definition_ids.append(attr.id)
-
-      definitions[attr_name] = {
-          "display_name": attr.title,
-          "attr_name": attr.title,
-          "mandatory": attr.mandatory,
-          "unique": False,
-          "description": description,
-          "type": ca_type,
-          "definition_ids": definition_ids,
-      }
+      if fields is None or attr.title.lower() in fields:
+        definitions[attr_name] = {
+            "display_name": attr.title,
+            "attr_name": attr.title,
+            "mandatory": attr.mandatory,
+            "unique": False,
+            "description": description,
+            "type": ca_type,
+            "definition_ids": definition_ids,
+        }
     return definitions
 
   @classmethod
@@ -383,7 +384,8 @@ def get_unique_constraints(cls, object_class):
     return set(sum(unique_columns, []))
 
   @classmethod
-  def get_object_attr_definitions(cls, object_class, ca_cache=None):
+  def get_object_attr_definitions(cls, object_class,
                                   ca_cache=None, fields=None):
     """Get all column definitions for object_class.
 
    This function joins custom attribute definitions, mapping definitions and
@@ -428,7 +430,7 @@ def get_object_attr_definitions(cls, object_class, ca_cache=None):
 
     if object_class.__name__ not in EXCLUDE_CUSTOM_ATTRIBUTES:
       definitions.update(cls.get_custom_attr_definitions(
-          object_class, ca_cache=ca_cache
+          object_class, ca_cache=ca_cache, fields=fields
      ))
 
     if object_class.__name__ not in EXCLUDE_MAPPINGS:
diff --git a/test/integration/ggrc/converters/test_acl_import_export.py b/test/integration/ggrc/converters/test_acl_import_export.py
index d3628873566f..0604dc687f44 100644
--- a/test/integration/ggrc/converters/test_acl_import_export.py
+++ b/test/integration/ggrc/converters/test_acl_import_export.py
@@ -325,8 +325,7 @@ def test_acl_revision_on_import(self):
     market_revisions = models.Revision.query.filter_by(
         resource_type="Market"
     ).count()
-    # One revision for created object and one for modified when acl was added
-    self.assertEqual(market_revisions, 2)
+    self.assertEqual(market_revisions, 1)
 
     acr_revisions = models.Revision.query.filter_by(
         resource_type="AccessControlRole"
diff --git a/test/integration/ggrc/converters/test_import_assessments.py b/test/integration/ggrc/converters/test_import_assessments.py
index 23bdca6ed57e..c5b21688bb3f 100644
--- a/test/integration/ggrc/converters/test_import_assessments.py
+++ b/test/integration/ggrc/converters/test_import_assessments.py
@@ -317,11 +317,10 @@ def test_assessment_warnings_errors(self):
                 column_name="Audit"
             ),
             errors.DUPLICATE_VALUE_IN_CSV.format(
-                line_list="20, 22",
+                line="22",
+                processed_line="20",
                 column_name="Code",
                 value="Assessment 22",
-                s="",
-                ignore_lines="22",
             ),
         },
         "row_warnings": {
diff --git a/test/integration/ggrc/converters/test_import_comprehensive.py b/test/integration/ggrc/converters/test_import_comprehensive.py
index 1db9b63bf493..24ae3d8841c7 100644
--- a/test/integration/ggrc/converters/test_import_comprehensive.py
+++ b/test/integration/ggrc/converters/test_import_comprehensive.py
@@ -50,7 +50,7 @@ def test_comprehensive_with_ca(self):
         "Objective": {
             "created": 8,
             "ignored": 7,
-            "row_errors": 5,
+            "row_errors": 7,
            "row_warnings": 4,
            "rows": 15,
        },
@@ -310,18 +310,16 @@ def test_case_sensitive_slugs(self):
        "Control": {
            "row_errors": {
                errors.DUPLICATE_VALUE_IN_CSV.format(
-                    line_list="3, 4",
+                    line="4",
+                    processed_line="3",
                    column_name="Code",
-                    s="",
-                    value="a",
-                    ignore_lines="4",
+                    value="A",
                ),
                errors.DUPLICATE_VALUE_IN_CSV.format(
-                    line_list="3, 4",
+                    line="4",
+                    processed_line="3",
                    column_name="Title",
-                    s="",
-                    value="a",
-                    ignore_lines="4",
+                    value="A",
                ),
            }
        }
diff --git a/test/integration/ggrc/converters/test_import_csv.py b/test/integration/ggrc/converters/test_import_csv.py
index dbf0bdcb0cca..1d84f4d8f360 100644
--- a/test/integration/ggrc/converters/test_import_csv.py
+++ b/test/integration/ggrc/converters/test_import_csv.py
@@ -4,6 +4,8 @@
 
 from collections import OrderedDict
 
+import ddt
+
 from ggrc import models
 from ggrc.converters import errors
 from integration.ggrc import TestCase
@@ -11,6 +13,7 @@
 from integration.ggrc.models import factories
 
 
+@ddt.ddt
 class TestBasicCsvImport(TestCase):
   """Test basic CSV imports."""
 
@@ -38,7 +41,7 @@ def test_policy_basic_import(self):
     revisions = models.Revision.query.filter(
         models.Revision.resource_type == "Policy"
     ).count()
-    self.assertEqual(revisions, 6)
+    self.assertEqual(revisions, 3)
     policy = models.Policy.eager_query().first()
     self.assertEqual(policy.modified_by.email, "user@example.com")
 
@@ -97,14 +100,26 @@ def test_owners(policy):
 
     expected_errors = {
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="3, 4, 6, 10, 11", column_name="Title",
-            value="A title", s="s", ignore_lines="4, 6, 10, 11"),
+            line="4", processed_line="3", column_name="Title",
+            value="A title"),
+        errors.DUPLICATE_VALUE_IN_CSV.format(
+            line="6", processed_line="3", column_name="Title",
+            value="A title"),
+        errors.DUPLICATE_VALUE_IN_CSV.format(
+            line="10", processed_line="3", column_name="Title",
+            value="A title"),
+        errors.DUPLICATE_VALUE_IN_CSV.format(
+            line="11", processed_line="3", column_name="Title",
+            value="A title"),
+        errors.DUPLICATE_VALUE_IN_CSV.format(
+            line="7", processed_line="5", column_name="Title",
+            value="A different title"),
+        errors.DUPLICATE_VALUE_IN_CSV.format(
+            line="9", processed_line="8", column_name="Code", value="code"),
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="5, 7", column_name="Title", value="A different title",
-            s="", ignore_lines="7"),
+            line="10", processed_line="8", column_name="Code", value="code"),
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="8, 9, 10, 11", column_name="Code", value="code",
-            s="s", ignore_lines="9, 10, 11"),
+            line="11", processed_line="8", column_name="Code", value="code"),
     }
     response_errors = response_json[0]["row_errors"]
     self.assertEqual(expected_errors, set(response_errors))
@@ -114,18 +129,82 @@ def test_owners(policy):
     for policy in policies:
       test_owners(policy)
 
-  def test_intermappings(self):
+  @ddt.data(True, False)
+  def test_intermappings(self, reverse_order):
+    """It is allowed to reference previous lines in map columns."""
     self.generate_people(["miha", "predrag", "vladan", "ivan"])
-
-    filename = "intermappings.csv"
-    response_json = self.import_file(filename)
-
+    facility_data_block = [
+        OrderedDict([
+            ("object_type", "facility"),
+            ("Code*", "HOUSE-{}".format(idx)),
+            ("title", "Facility-{}".format(idx)),
+            ("admin", "user@example.com"),
+            ("map:facility", "" if idx == 1 else "HOUSE-{}".format(idx - 1)),
+        ])
+        for idx in (xrange(1, 5) if not reverse_order else xrange(4, 0, -1))
+    ]
+    objective_data_block = [
+        OrderedDict([
+            ("object_type", "objective"),
+            ("Code*", "O1"),
+            ("title", "House of cards"),
+            ("admin", "user@example.com"),
+            ("map:facility", "HOUSE-2"),
+            ("map:objective", ""),
+        ]),
+        OrderedDict([
+            ("object_type", "objective"),
+            ("Code*", "O2"),
+            ("title", "House of the rising sun"),
+            ("admin", "user@example.com"),
+            ("map:facility", "HOUSE-3"),
+            ("map:objective", "O1\nO2\nO3"),
+        ]),
+        OrderedDict([
+            ("object_type", "objective"),
+            ("Code*", "O3"),
+            ("title", "Yellow house"),
+            ("admin", "user@example.com"),
+            ("map:facility", "HOUSE-4"),
+            ("map:objective", ""),
+        ]),
+        OrderedDict([
+            ("object_type", "objective"),
+            ("Code*", "O4"),
+            ("title", "There is no place like home"),
+            ("admin", "user@example.com"),
+            ("map:facility", "HOUSE-1"),
+            ("map:objective", "O3\nO4\nO3"),
+        ]),
+    ]
+
+    response_json = self.import_data(
+        *(facility_data_block + objective_data_block)
+    )
     self.assertEqual(4, response_json[0]["created"])  # Facility
     self.assertEqual(4, response_json[1]["created"])  # Objective
 
-    response_warnings = response_json[0]["row_warnings"]
-    self.assertEqual(set(), set(response_warnings))
-    self.assertEqual(13, models.Relationship.query.count())
+    if reverse_order:
+      expected_block_1 = set([
+          u"Line {line}: Facility 'house-{idx}' "
+          u"doesn't exist, so it can't be mapped/unmapped.".format(
+              idx=idx,
+              line=6 - idx
+          )
+          for idx in xrange(1, 4)
+      ])
+      rel_numbers = 8
+    else:
+      expected_block_1 = set()
+      rel_numbers = 11
+
+    expected_block_2 = {
+        u"Line 11: Objective 'o3' doesn't exist, "
+        u"so it can't be mapped/unmapped."
+    }
+    self.assertEqual(expected_block_1, set(response_json[0]["row_warnings"]))
+    self.assertEqual(expected_block_2, set(response_json[1]["row_warnings"]))
+    self.assertEqual(rel_numbers, models.Relationship.query.count())
 
     obj2 = models.Objective.query.filter_by(slug="O2").first()
     obj3 = models.Objective.query.filter_by(slug="O3").first()
diff --git a/test/integration/ggrc/converters/test_import_update.py b/test/integration/ggrc/converters/test_import_update.py
index c2b8be2f74e7..4d03f7e363f9 100644
--- a/test/integration/ggrc/converters/test_import_update.py
+++ b/test/integration/ggrc/converters/test_import_update.py
@@ -26,7 +26,7 @@ def test_policy_basic_update(self):
         models.Revision.resource_type == "Policy",
         models.Revision.resource_id == policy.id
     ).count()
-    self.assertEqual(revision_count, 2)
+    self.assertEqual(revision_count, 1)
 
     self.import_file("policy_basic_import_update.csv")
 
@@ -36,7 +36,7 @@ def test_policy_basic_update(self):
         models.Revision.resource_type == "Policy",
         models.Revision.resource_id == policy.id
     ).count()
-    self.assertEqual(revision_count, 4)
+    self.assertEqual(revision_count, 2)
 
     self.assertEqual(
         policy.access_control_list[0].person.email, "user1@example.com"
    )
diff --git a/test/integration/ggrc/converters/test_multi_import_csv.py b/test/integration/ggrc/converters/test_multi_import_csv.py
index 7bfb2cee663b..a054e416e4bd 100644
--- a/test/integration/ggrc/converters/test_multi_import_csv.py
+++ b/test/integration/ggrc/converters/test_multi_import_csv.py
@@ -72,17 +72,21 @@ def test_multi_basic_policy_orggroup_product_with_warnings(self):
 
     expected_warnings = set([
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="5, 6", column_name="Title", value="dolor",
-            s="", ignore_lines="6"),
+            line="6", processed_line="5", column_name="Title", value="dolor",
+        ),
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="6, 7", column_name="Code", value="p-4",
-            s="", ignore_lines="7"),
+            line="7", processed_line="6", column_name="Code", value="p-4",
+        ),
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="21, 26", column_name="Title", value="meatloaf",
-            s="", ignore_lines="26"),
+            line="26", processed_line="21", column_name="Title",
+            value="meatloaf",
+        ),
         errors.DUPLICATE_VALUE_IN_CSV.format(
-            line_list="21, 26, 27", column_name="Code", value="pro 1",
-            s="s", ignore_lines="26, 27"),
+            line="26", processed_line="21", column_name="Code", value="pro 1",
+        ),
+        errors.DUPLICATE_VALUE_IN_CSV.format(
+            line="27", processed_line="21", column_name="Code", value="pro 1",
+        ),
         errors.OWNER_MISSING.format(line=26, column_name="Admin"),
         errors.MISSING_COLUMN.format(line=13, column_names="Admin", s=""),
         errors.MISSING_COLUMN.format(line=14, column_names="Admin", s=""),
diff --git a/test/integration/ggrc/services/resources/test_converters.py b/test/integration/ggrc/services/resources/test_converters.py
index 59accec1d116..24b4b0631b97 100644
--- a/test/integration/ggrc/services/resources/test_converters.py
+++ b/test/integration/ggrc/services/resources/test_converters.py
@@ -357,7 +357,7 @@ def test_import_control_revisions(self):
         all_models.Revision.resource_type == "Control",
         all_models.Revision.resource_id == control.id
     )
-    self.assertEqual({"created", "modified"}, {a[0] for a in revision_actions})
+    self.assertEqual({"created"}, {a[0] for a in revision_actions})
 
   @mock.patch(
       "ggrc.gdrive.file_actions.get_gdrive_file_data",
diff --git a/test/integration/ggrc/test_csvs/intermappings.csv b/test/integration/ggrc/test_csvs/intermappings.csv
deleted file mode 100644
index 486cabf8a259..000000000000
--- a/test/integration/ggrc/test_csvs/intermappings.csv
+++ /dev/null
@@ -1,18 +0,0 @@
-Object type,,,,,
-facility,code,title,admin,map:facility,
-,HOUSE-1,House of cards,user@example.com,HOUSE-2,
-,HOUSE-2,House of the rising sun,user@example.com,HOUSE-3,
-,HOUSE-3,Yellow house,user@example.com,HOUSE-4,
-,HOUSE-4,There is no place like home,user@example.com,HOUSE-1,
-,,,,,
-,,,,,
-Object type,,,,,
-Objective,code,title,admin,map:facility,map:objective
-,O1,House of cards,user@example.com,HOUSE-2,
-,O2,House of the rising sun,user@example.com,HOUSE-3,"O1
-O2
-O3"
-,O3,Yellow house,user@example.com,HOUSE-4,
-,O4,There is no place like home,user@example.com,HOUSE-1,"O3
-O4
-O3"
\ No newline at end of file
diff --git a/test/integration/ggrc_basic_permissions/test_audit_archiving.py b/test/integration/ggrc_basic_permissions/test_audit_archiving.py
index 56f2eb436d8b..f1fe91f92f35 100644
--- a/test/integration/ggrc_basic_permissions/test_audit_archiving.py
+++ b/test/integration/ggrc_basic_permissions/test_audit_archiving.py
@@ -355,12 +355,20 @@ def test_audit_context_editing(self, person, status, objects):
   def test_audit_snapshot_editing(self, person, status, obj):
     """Test if {0} can edit objects in the audit context: {1} - {2}"""
     self.api.set_user(self.people[person])
-    obj_instance = getattr(self, obj)
+    obj_instance_id = getattr(self, obj).id
+    snapshot = all_models.Snapshot.query.get(obj_instance_id)
+    # update obj to create new revision
+    self.api.put(
+        all_models.Objective.query.get(snapshot.revision.resource_id),
+        {
+            "status": "Active",
+        }
+    )
     json = {
         "update_revision": "latest"
     }
-    response = self.api.put(obj_instance, json)
+    response = self.api.put(snapshot, json)
     assert response.status_code == status, \
         "{} put returned {} instead of {} for {}".format(
             person, response.status, status, obj)