Skip to content

Commit

Permalink
ES-1986 (#1078)
Browse files Browse the repository at this point in the history
* ES-1986

Signed-off-by: ase-101 <[email protected]>

* ES-1986

Signed-off-by: ase-101 <[email protected]>

---------

Signed-off-by: ase-101 <[email protected]>
  • Loading branch information
ase-101 authored Jan 7, 2025
1 parent 0448094 commit 44948ac
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,13 @@ public interface TokenService {

/**
* Creates ID token with the given subject and audience and nonce
*
* @param subject
* @param audience
* @param validitySeconds
* @param transaction
* @param nonce
* @return
*/
String getIDToken(String subject, String audience, int validitySeconds, OIDCTransaction transaction);
String getIDToken(String subject, String audience, int validitySeconds, OIDCTransaction transaction, String nonce);
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ mosip.esignet.cache.security.secretkey.reference-id=TRANSACTION_CACHE
mosip.esignet.cache.security.algorithm-name=AES/ECB/PKCS5Padding
mosip.esignet.cache.key.hash.algorithm=SHA3-256

mosip.esignet.cache.names=clientdetails,preauth,authenticated,authcodegenerated,userinfo,linkcodegenerated,linked,linkedcode,linkedauth,consented,authtokens,bindingtransaction,apiratelimit,blocked,halted
mosip.esignet.cache.names=clientdetails,preauth,authenticated,authcodegenerated,userinfo,linkcodegenerated,linked,linkedcode,linkedauth,consented,authtokens,bindingtransaction,apiratelimit,blocked,halted,nonce

# 'simple' cache type is only applicable only for Non-Production setup
spring.cache.type=redis
Expand All @@ -199,9 +199,11 @@ mosip.esignet.cache.size={'clientdetails' : 200, \
'bindingtransaction': 200, \
'apiratelimit' : 500, \
'blocked': 500, \
'halted' : 500 }
'halted' : 500,\
'nonce' : 500 }

# Cache expire in seconds is applicable for both 'simple' and 'Redis' cache type
# TTL of 'authtokens' cache depends on the auth token expire time acquired from IAM / MOSIP authmanager.
mosip.esignet.cache.expire-in-seconds={'clientdetails' : 86400, \
'preauth': ${mosip.esignet.preauthentication-expire-in-secs},\
'authenticated': ${mosip.esignet.authentication-expire-in-secs}, \
Expand All @@ -212,11 +214,12 @@ mosip.esignet.cache.expire-in-seconds={'clientdetails' : 86400, \
'linkedcode': ${mosip.esignet.link-code-expire-in-secs}, \
'linkedauth' : ${mosip.esignet.authentication-expire-in-secs}, \
'consented': 600, \
'authtokens': 86400, \
'authtokens': 7200, \
'bindingtransaction': 600, \
'apiratelimit' : 180, \
'blocked': 300, \
'halted' : ${mosip.esignet.signup.halt.expire-seconds} }
'halted' : ${mosip.esignet.signup.halt.expire-seconds}, \
'nonce' : 86400 }

## ------------------------------------------ Discovery openid-configuration -------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.RedisScriptingCommands;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
Expand All @@ -65,6 +70,7 @@
import java.util.Map;

import static io.mosip.esignet.core.constants.Constants.UTC_DATETIME_PATTERN;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
Expand Down Expand Up @@ -121,6 +127,17 @@ public class AuthCodeFlowTest {
public void init() throws Exception {
createOIDCClient(clientId, clientJWK.toPublicJWK(), replyingPartyId);
log.info("Successfully create OIDC Client {}", clientId);

RedisScriptingCommands redisScriptingCommands = Mockito.mock(RedisScriptingCommands.class);
RedisConnection redisConnection = Mockito.mock(RedisConnection.class);
RedisConnectionFactory redisConnectionFactory = Mockito.mock(RedisConnectionFactory.class);
when(redisConnectionFactory.getConnection()).thenReturn(redisConnection);
when(redisConnection.scriptingCommands()).thenReturn(redisScriptingCommands);
when(redisScriptingCommands.evalSha(anyString(), any(ReturnType.class), anyInt(), any(), any())).thenReturn(1L);

ReflectionTestUtils.setField(cacheUtilService, "redisConnectionFactory", redisConnectionFactory);
ReflectionTestUtils.setField(cacheUtilService, "nonceScriptHash", "nonceScriptHash");
ReflectionTestUtils.setField(cacheUtilService, "nonceValidity", 86400);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,13 @@
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.RedisScriptingCommands;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.http.MediaType;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
Expand All @@ -60,6 +65,9 @@

import static io.mosip.esignet.api.util.ErrorConstants.AUTH_FAILED;
import static io.mosip.esignet.core.constants.Constants.UTC_DATETIME_PATTERN;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

Expand Down Expand Up @@ -104,6 +112,20 @@ public class AuthorizationAPIFlowTest {
private JWK clientJWK = TestUtil.generateJWK_RSA();
private boolean created = false;

@Before
public void init() {
RedisScriptingCommands redisScriptingCommands = Mockito.mock(RedisScriptingCommands.class);
RedisConnection redisConnection = Mockito.mock(RedisConnection.class);
RedisConnectionFactory redisConnectionFactory = Mockito.mock(RedisConnectionFactory.class);
when(redisConnectionFactory.getConnection()).thenReturn(redisConnection);
when(redisConnection.scriptingCommands()).thenReturn(redisScriptingCommands);
when(redisScriptingCommands.evalSha(anyString(), any(ReturnType.class), anyInt(), any(), any())).thenReturn(1L);

ReflectionTestUtils.setField(cacheUtilService, "redisConnectionFactory", redisConnectionFactory);
ReflectionTestUtils.setField(cacheUtilService, "nonceScriptHash", "nonceScriptHash");
ReflectionTestUtils.setField(cacheUtilService, "nonceValidity", 86400);
}


@Test
public void invalidClientId_thenFail() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.mosip.kernel.keymanagerservice.helper.KeymanagerDBHelper;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.util.Pair;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
Expand Down Expand Up @@ -232,7 +233,7 @@ protected KycAuthResult handleInternalAuthenticateRequest(@NonNull AuthChallenge
String subject = jwt.getJWTClaimsSet().getSubject();

//compares individual from auth request against subject from jwt token.
if(!individualId.equals(subject)){
if(!individualId.equals(subject)) {
throw new EsignetException(INVALID_INDIVIDUAL_ID);
}

Expand All @@ -241,9 +242,14 @@ protected KycAuthResult handleInternalAuthenticateRequest(@NonNull AuthChallenge
.findFirst();
OIDCTransaction haltedTransaction = cacheUtilService.getHaltedTransaction(subject);

//Checks to confirm that the ID token is not mis-used or re-used
//Validate if cookie is present with token subject as name and halted transaction is present in cache
if(result.isPresent() && haltedTransaction != null && haltedTransaction.getServerNonce().equals(
result.get().getValue().split(SERVER_NONCE_SEPARATOR)[0])) {
//validate if the server nonce in the halted transaction is same as the nonce in the ID token
//validate if the nonce in the ID token is same as the nonce in the current OIDC transaction
if(result.isPresent() && haltedTransaction != null &&
haltedTransaction.getServerNonce().equals(result.get().getValue().split(SERVER_NONCE_SEPARATOR)[0]) &&
haltedTransaction.getServerNonce().equals(jwt.getJWTClaimsSet().getStringClaim(TokenService.NONCE)) &&
transaction.getNonce().equals(jwt.getJWTClaimsSet().getStringClaim(TokenService.NONCE))) {
transaction.setIndividualId(haltedTransaction.getIndividualId());
KycAuthResult kycAuthResult = new KycAuthResult();
kycAuthResult.setKycToken(subject);
Expand Down Expand Up @@ -402,7 +408,7 @@ private String getKeyAlias(String keyAppId, String keyRefId) {
throw new EsignetException(NO_UNIQUE_ALIAS);
}

protected String validateAndGetSubject(String clientId, String idTokenHint) {
protected Pair<String,String> validateAndGetSubjectAndNonce(String clientId, String idTokenHint) {
try {
String[] jwtParts = idTokenHint.split("\\.");
if (jwtParts.length == 3) {
Expand All @@ -411,7 +417,7 @@ protected String validateAndGetSubject(String clientId, String idTokenHint) {
String audience = payloadJson.getString(TokenService.AUD);
if(!signupIDTokenAudience.equals(audience) || !signupIDTokenAudience.equals(clientId))
throw new EsignetException(ErrorConstants.INVALID_ID_TOKEN_HINT);
return payloadJson.getString(TokenService.SUB);
return Pair.of(payloadJson.getString(TokenService.SUB), payloadJson.getString(TokenService.NONCE));
}
} catch (Exception e) {
log.error("Failed to parse the given IDTokenHint as JWT", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,15 @@ public OAuthDetailResponseV2 getOauthDetailsV2(OAuthDetailRequestV2 oauthDetailR
public OAuthDetailResponseV2 getOauthDetailsV3(OAuthDetailRequestV3 oauthDetailReqDto, HttpServletRequest httpServletRequest) throws EsignetException {
//id_token_hint is an optional parameter, if provided then it is expected to be a valid JWT
if (oauthDetailReqDto.getIdTokenHint() != null) {
String subject = authorizationHelperService.validateAndGetSubject(oauthDetailReqDto.getClientId(), oauthDetailReqDto.getIdTokenHint());
Pair<String, String> pair = authorizationHelperService.validateAndGetSubjectAndNonce(oauthDetailReqDto.getClientId(), oauthDetailReqDto.getIdTokenHint());
if(httpServletRequest.getCookies() == null)
throw new EsignetException(ErrorConstants.INVALID_ID_TOKEN_HINT);
Optional<Cookie> result = Arrays.stream(httpServletRequest.getCookies()).filter(x -> x.getName().equals(subject)).findFirst();
Optional<Cookie> result = Arrays.stream(httpServletRequest.getCookies()).filter(x -> x.getName().equals(pair.getFirst())).findFirst();
if (result.isEmpty()) {
throw new EsignetException(ErrorConstants.INVALID_ID_TOKEN_HINT);
}
String[] parts = result.get().getValue().split(SERVER_NONCE_SEPARATOR);
oauthDetailReqDto.setNonce(pair.getSecond());
oauthDetailReqDto.setState(parts.length == 2? parts[1] : result.get().getValue());
}
return getOauthDetailsV2(oauthDetailReqDto);
Expand Down Expand Up @@ -300,7 +301,8 @@ public SignupRedirectResponse prepareSignupRedirect(SignupRedirectRequest signup

SignupRedirectResponse signupRedirectResponse = new SignupRedirectResponse();
signupRedirectResponse.setTransactionId(signupRedirectRequest.getTransactionId());
signupRedirectResponse.setIdToken(tokenService.getIDToken(signupRedirectRequest.getTransactionId(), signupIDTokenAudience, signupIDTokenValidity, oidcTransaction));
signupRedirectResponse.setIdToken(tokenService.getIDToken(signupRedirectRequest.getTransactionId(), signupIDTokenAudience, signupIDTokenValidity,
oidcTransaction, oidcTransaction.getServerNonce()));

//Move the transaction to halted transaction
cacheUtilService.setHaltedTransaction(signupRedirectRequest.getTransactionId(), oidcTransaction);
Expand Down Expand Up @@ -395,6 +397,7 @@ private Pair<OAuthDetailResponse, OIDCTransaction> checkAndBuildOIDCTransaction(
OAuthDetailResponse oAuthDetailResponse) {
log.info("nonce : {} Valid client id found, proceeding to validate redirect URI", oauthDetailReqDto.getNonce());
IdentityProviderUtil.validateRedirectURI(clientDetailDto.getRedirectUris(), oauthDetailReqDto.getRedirectUri());
validateNonce(oauthDetailReqDto.getNonce());

//Resolve the final set of claims based on registered and request parameter.
Claims resolvedClaims = claimsHelperService.resolveRequestedClaims(oauthDetailReqDto, clientDetailDto);
Expand Down Expand Up @@ -503,6 +506,12 @@ private String getAuthTransactionId(String oidcTransactionId) {
return new String(authTransactionIdBytes);
}

private void validateNonce(String nonce) {
if(nonce == null || nonce.isBlank())
return;

if(cacheUtilService.checkNonce(nonce.trim()) == 0L)
throw new EsignetException(ErrorConstants.INVALID_REQUEST);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,43 @@
import io.mosip.esignet.core.util.IdentityProviderUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cache.CacheManager;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.CachePut;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.cache.annotation.Caching;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;

import static io.mosip.esignet.core.util.IdentityProviderUtil.ALGO_SHA3_256;


@Slf4j
@Service
public class CacheUtilService {

private static final String NONCE_CHECK_SCRIPT = "if redis.call(\"EXISTS\", KEYS[1]) == 0 then\n" +
" redis.call(\"SETEX\", KEYS[1], tonumber(ARGV[1]), \"1\")\n" +
" return 1\n" +
"else\n" +
" return 0\n" +
"end";
private String nonceScriptHash = null;
private static String NONCE_KEY = "nonce::%s";

@Value("${mosip.esignet.nonce-expire-seconds:86400}")
private int nonceValidity;

@Autowired
CacheManager cacheManager;

@Autowired
private RedisConnectionFactory redisConnectionFactory;

@Cacheable(value = Constants.PRE_AUTH_SESSION_CACHE, key = "#transactionId")
public OIDCTransaction setTransaction(String transactionId, OIDCTransaction oidcTransaction) {
return oidcTransaction;
Expand Down Expand Up @@ -75,6 +95,24 @@ public void removeHaltedTransaction(String transactionId) {
log.debug("Evicting entry from HALTED_CACHE");
}

public long checkNonce(String nonce) {
if (redisConnectionFactory.getConnection() != null) {
if (nonceScriptHash == null) {
nonceScriptHash = redisConnectionFactory.getConnection().scriptingCommands().scriptLoad(NONCE_CHECK_SCRIPT.getBytes());
}
log.info("Running NONCE_CHECK_SCRIPT script: {}", nonceScriptHash);
final String key = String.format(NONCE_KEY, nonce);
return redisConnectionFactory.getConnection().scriptingCommands().evalSha(
nonceScriptHash,
ReturnType.INTEGER,
1, // Number of keys
key.getBytes(), // The Redis hash name (key)
String.valueOf(nonceValidity).getBytes(StandardCharsets.UTF_8) // ttl
);
}
return 0;
}

//---------------------------------------------- Linked authorization ----------------------------------------------

@CacheEvict(value = Constants.PRE_AUTH_SESSION_CACHE, key = "#transactionId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.*;

Expand Down Expand Up @@ -92,15 +91,15 @@ public class TokenServiceImpl implements TokenService {
@Override
public String getIDToken(@NonNull OIDCTransaction transaction) {
JSONObject payload = buildIDToken(transaction.getPartnerSpecificUserToken(),
transaction.getClientId(), idTokenExpireSeconds, transaction);
transaction.getClientId(), idTokenExpireSeconds, transaction, null);
payload.put(ACCESS_TOKEN_HASH, transaction.getAHash());
return getSignedJWT(Constants.OIDC_SERVICE_APP_ID, payload);
}

@Override
public String getIDToken(@NonNull String subject, @NonNull String audience, int validitySeconds,
@NonNull OIDCTransaction transaction) {
JSONObject payload = buildIDToken(subject, audience, validitySeconds, transaction);
@NonNull OIDCTransaction transaction, String nonce) {
JSONObject payload = buildIDToken(subject, audience, validitySeconds, transaction, nonce);
return getSignedJWT(Constants.OIDC_SERVICE_APP_ID, payload);
}

Expand Down Expand Up @@ -213,7 +212,7 @@ public String getSignedJWT(String applicationId, JSONObject payload) {
}

private JSONObject buildIDToken(String subject, String audience, int validitySeconds,
OIDCTransaction transaction) {
OIDCTransaction transaction, String nonce) {
JSONObject payload = new JSONObject();
payload.put(ISS, issuerId);
payload.put(SUB, subject);
Expand All @@ -222,7 +221,7 @@ private JSONObject buildIDToken(String subject, String audience, int validitySec
payload.put(IAT, issueTime);
payload.put(EXP, issueTime + (validitySeconds<=0 ? 3600 : validitySeconds));
payload.put(AUTH_TIME, transaction.getAuthTimeInSeconds());
payload.put(NONCE, transaction.getNonce());
payload.put(NONCE, nonce == null ? transaction.getNonce() : nonce);
List<String> acrs = authenticationContextClassRefUtil.getACRs(transaction.getProvidedAuthFactors());
payload.put(ACR, String.join(SPACE, acrs));
return payload;
Expand Down
Loading

0 comments on commit 44948ac

Please sign in to comment.