Skip to content

Commit

Permalink
Adding Springs AuditorAware capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
agordon-vivid committed Dec 12, 2024
1 parent 2844251 commit e7be0a2
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package ca.bc.gov.nrs.wfprev;

import org.springframework.data.domain.AuditorAware;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal;
import org.springframework.stereotype.Component;

import java.util.Optional;

@Component
public class SpringSecurityAuditorAware implements AuditorAware<String> {

@Override
public Optional<String> getCurrentAuditor() {
return Optional.ofNullable(SecurityContextHolder.getContext())
.map(context -> context.getAuthentication())
.filter(Authentication::isAuthenticated)
.map(authentication -> {
Object principal = authentication.getPrincipal();
if (principal instanceof DefaultOAuth2AuthenticatedPrincipal) {
// Extract username or preferred identifier
DefaultOAuth2AuthenticatedPrincipal oauthPrincipal = (DefaultOAuth2AuthenticatedPrincipal) principal;
return (String) oauthPrincipal.getAttribute("preferred_username"); // Adjust key to match your provider
}
throw new IllegalStateException("Principal is not of type DefaultOAuth2AuthenticatedPrincipal");
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.core.Ordered;
import org.springframework.data.jpa.repository.config.EnableJpaAuditing;
import org.springframework.web.filter.ForwardedHeaderFilter;

import com.fasterxml.jackson.annotation.JsonFormat;
Expand All @@ -19,6 +20,7 @@
import jakarta.servlet.DispatcherType;

@SpringBootApplication
@EnableJpaAuditing(auditorAwareRef = "springSecurityAuditorAware")
public class WfprevApiApplication {
/*
* Run the application as a JAR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EntityListeners;
import jakarta.persistence.FetchType;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
Expand All @@ -24,13 +25,14 @@
import org.springframework.data.annotation.CreatedDate;
import org.springframework.data.annotation.LastModifiedBy;
import org.springframework.data.annotation.LastModifiedDate;
import org.springframework.data.jpa.domain.support.AuditingEntityListener;

import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Date;
import java.util.UUID;

@Entity
@EntityListeners(AuditingEntityListener.class)
@Table(name = "project")
@JsonIgnoreProperties(ignoreUnknown = false)
@Data
Expand All @@ -39,6 +41,7 @@
@AllArgsConstructor
@NoArgsConstructor
public class ProjectEntity implements Serializable {

@Id
@UuidGenerator
@GeneratedValue(strategy = GenerationType.UUID)
Expand All @@ -49,7 +52,7 @@ public class ProjectEntity implements Serializable {
@JoinColumn(name = "project_type_code")
private ProjectTypeCodeEntity projectTypeCode;

@Column(name = "project_number", columnDefinition="Decimal(10)", insertable = false, updatable = true)
@Column(name = "project_number", columnDefinition = "Decimal(10)", insertable = false, updatable = true)
private Integer projectNumber;

@NotNull
Expand Down Expand Up @@ -81,7 +84,7 @@ public class ProjectEntity implements Serializable {
@Column(name = "bc_parks_region_org_unit_id", columnDefinition = "Decimal(10)")
private Integer bcParksRegionOrgUnitId;

@Column(name = "bcParksSectionOrgUnitId", columnDefinition = "Decimal(10)")
@Column(name = "bc_parks_section_org_unit_id", columnDefinition = "Decimal(10)")
private Integer bcParksSectionOrgUnitId;

@NotNull
Expand Down Expand Up @@ -144,7 +147,7 @@ public class ProjectEntity implements Serializable {
private Integer revisionCount;

@ManyToOne(fetch = FetchType.EAGER, optional = false)
@JoinColumn(name="project_status_code")
@JoinColumn(name = "project_status_code")
private ProjectStatusCodeEntity projectStatusCode;

@CreatedBy
Expand All @@ -166,4 +169,4 @@ public class ProjectEntity implements Serializable {
@NotNull
@Column(name = "update_date")
private Date updateDate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,9 @@ public ProjectModel createOrUpdateProject(ProjectModel resource) throws ServiceE
try {
if (resource.getProjectGuid() == null) {
resource.setCreateDate(new Date());
//TODO - Fix to use proper user
resource.setCreateUser("SYSTEM");
resource.setProjectGuid(UUID.randomUUID().toString());
resource.setCreateDate(new Date());
resource.setCreateUser("SYSTEM");
resource.setRevisionCount(0); // Initialize revision count for new records
}
// Set audit fields
resource.setUpdateDate(new Date());
//TODO - Fix to use proper user
resource.setUpdateUser("SYSTEM");

ProjectEntity entity = projectResourceAssembler.toEntity(resource);

// Load ForestAreaCode with null checks
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package ca.bc.gov.nrs.wfprev;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal;

import java.util.Optional;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

class SpringSecurityAuditorAwareTest {

private SpringSecurityAuditorAware auditorAware;

private MockedStatic<SecurityContextHolder> securityContextHolderMock;

@BeforeEach
void setUp() {
auditorAware = new SpringSecurityAuditorAware();
securityContextHolderMock = Mockito.mockStatic(SecurityContextHolder.class);
}

@AfterEach
void tearDown() {
securityContextHolderMock.close();
}

@Test
void getCurrentAuditor_authenticatedUser_returnsUsername() {
// Given: A valid SecurityContext with an authenticated user
SecurityContext mockSecurityContext = mock(SecurityContext.class);
Authentication mockAuthentication = mock(Authentication.class);
DefaultOAuth2AuthenticatedPrincipal mockPrincipal = mock(DefaultOAuth2AuthenticatedPrincipal.class);

when(SecurityContextHolder.getContext()).thenReturn(mockSecurityContext);
when(mockSecurityContext.getAuthentication()).thenReturn(mockAuthentication);
when(mockAuthentication.isAuthenticated()).thenReturn(true);
when(mockAuthentication.getPrincipal()).thenReturn(mockPrincipal);
when(mockPrincipal.getAttribute("preferred_username")).thenReturn("test_user");

// When: getCurrentAuditor is called
Optional<String> result = auditorAware.getCurrentAuditor();

// Then: The correct username is returned
assertEquals(Optional.of("test_user"), result);
}

@Test
void getCurrentAuditor_userNotAuthenticated_returnsEmptyOptional() {
// Given: A valid SecurityContext with an unauthenticated user
SecurityContext mockSecurityContext = mock(SecurityContext.class);
Authentication mockAuthentication = mock(Authentication.class);

when(SecurityContextHolder.getContext()).thenReturn(mockSecurityContext);
when(mockSecurityContext.getAuthentication()).thenReturn(mockAuthentication);
when(mockAuthentication.isAuthenticated()).thenReturn(false);

// When: getCurrentAuditor is called
Optional<String> result = auditorAware.getCurrentAuditor();

// Then: An empty Optional is returned
assertEquals(Optional.empty(), result);
}

@Test
void getCurrentAuditor_invalidPrincipalType_throwsIllegalStateException() {
// Given: A valid SecurityContext with an invalid principal type
SecurityContext mockSecurityContext = mock(SecurityContext.class);
Authentication mockAuthentication = mock(Authentication.class);

when(SecurityContextHolder.getContext()).thenReturn(mockSecurityContext);
when(mockSecurityContext.getAuthentication()).thenReturn(mockAuthentication);
when(mockAuthentication.isAuthenticated()).thenReturn(true);
when(mockAuthentication.getPrincipal()).thenReturn("InvalidPrincipal");

// When & Then: getCurrentAuditor throws an IllegalStateException
IllegalStateException exception = assertThrows(IllegalStateException.class, auditorAware::getCurrentAuditor);
assertEquals("Principal is not of type DefaultOAuth2AuthenticatedPrincipal", exception.getMessage());
}

@Test
void getCurrentAuditor_nullSecurityContext_returnsEmptyOptional() {
// Given: A null SecurityContext
when(SecurityContextHolder.getContext()).thenReturn(null);

// When: getCurrentAuditor is called
Optional<String> result = auditorAware.getCurrentAuditor();

// Then: An empty Optional is returned
assertEquals(Optional.empty(), result);
}

@Test
void getCurrentAuditor_nullAuthentication_returnsEmptyOptional() {
// Given: A valid SecurityContext with a null Authentication
SecurityContext mockSecurityContext = mock(SecurityContext.class);

when(SecurityContextHolder.getContext()).thenReturn(mockSecurityContext);
when(mockSecurityContext.getAuthentication()).thenReturn(null);

// When: getCurrentAuditor is called
Optional<String> result = auditorAware.getCurrentAuditor();

// Then: An empty Optional is returned
assertEquals(Optional.empty(), result);
}
}

0 comments on commit e7be0a2

Please sign in to comment.