diff --git a/lib/redis.core-module.ts b/lib/redis.core-module.ts index 2d33e66..b8963cb 100644 --- a/lib/redis.core-module.ts +++ b/lib/redis.core-module.ts @@ -1,14 +1,46 @@ -import { DynamicModule, Module, Global, Provider } from '@nestjs/common'; -import { RedisModuleAsyncOptions, RedisModuleOptions, RedisModuleOptionsFactory } from './redis.interfaces'; -import { createRedisConnection, getRedisOptionsToken, getRedisConnectionToken } from './redis.utils' +import { + DynamicModule, + Module, + Global, + Provider, + OnApplicationShutdown, +} from '@nestjs/common'; +import { + RedisModuleAsyncOptions, + RedisModuleOptions, + RedisModuleOptionsFactory, +} from './redis.interfaces'; +import { + createRedisConnection, + getRedisOptionsToken, + getRedisConnectionToken, + tryCloseRedisConnectionPermanently, +} from './redis.utils'; +import Redis, { Cluster } from 'ioredis'; @Global() @Module({}) -export class RedisCoreModule { +export class RedisCoreModule implements OnApplicationShutdown { + private static readonly redisConnections = [] as Array< + WeakRef + >; + + public async onApplicationShutdown() { + await Promise.all( + RedisCoreModule.redisConnections.map(async (connection) => { + const redis = connection.deref(); + if (redis) { + await tryCloseRedisConnectionPermanently(redis); + } + }), + ); + } /* forRoot */ - static forRoot(options: RedisModuleOptions, connection?: string): DynamicModule { - + static forRoot( + options: RedisModuleOptions, + connection?: string, + ): DynamicModule { const redisOptionsProvider: Provider = { provide: getRedisOptionsToken(connection), useValue: options, @@ -16,29 +48,25 @@ export class RedisCoreModule { const redisConnectionProvider: Provider = { provide: getRedisConnectionToken(connection), - useValue: createRedisConnection(options), + useValue: RedisCoreModule.createAndTrackRedisConnection(options), }; return { module: RedisCoreModule, - providers: [ - redisOptionsProvider, - redisConnectionProvider, - ], - exports: [ - redisOptionsProvider, - redisConnectionProvider, - ], + providers: [redisOptionsProvider, redisConnectionProvider], + exports: [redisOptionsProvider, redisConnectionProvider], }; } /* forRootAsync */ - public static forRootAsync(options: RedisModuleAsyncOptions, connection: string): DynamicModule { - + public static forRootAsync( + options: RedisModuleAsyncOptions, + connection?: string, + ): DynamicModule { const redisConnectionProvider: Provider = { provide: getRedisConnectionToken(connection), useFactory(options: RedisModuleOptions) { - return createRedisConnection(options) + return RedisCoreModule.createAndTrackRedisConnection(options); }, inject: [getRedisOptionsToken(connection)], }; @@ -46,22 +74,27 @@ export class RedisCoreModule { return { module: RedisCoreModule, imports: options.imports, - providers: [...this.createAsyncProviders(options, connection), redisConnectionProvider], + providers: [ + ...this.createAsyncProviders(options, connection), + redisConnectionProvider, + ], exports: [redisConnectionProvider], }; } /* createAsyncProviders */ - public static createAsyncProviders(options: RedisModuleAsyncOptions, connection?: string): Provider[] { - - if(!(options.useExisting || options.useFactory || options.useClass)) { - throw new Error('Invalid configuration. Must provide useFactory, useClass or useExisting'); + public static createAsyncProviders( + options: RedisModuleAsyncOptions, + connection?: string, + ): Provider[] { + if (!(options.useExisting || options.useFactory || options.useClass)) { + throw new Error( + 'Invalid configuration. Must provide useFactory, useClass or useExisting', + ); } if (options.useExisting || options.useFactory) { - return [ - this.createAsyncOptionsProvider(options, connection) - ]; + return [this.createAsyncOptionsProvider(options, connection)]; } return [ @@ -71,10 +104,14 @@ export class RedisCoreModule { } /* createAsyncOptionsProvider */ - public static createAsyncOptionsProvider(options: RedisModuleAsyncOptions, connection?: string): Provider { - - if(!(options.useExisting || options.useFactory || options.useClass)) { - throw new Error('Invalid configuration. Must provide useFactory, useClass or useExisting'); + public static createAsyncOptionsProvider( + options: RedisModuleAsyncOptions, + connection?: string, + ): Provider { + if (!(options.useExisting || options.useFactory || options.useClass)) { + throw new Error( + 'Invalid configuration. Must provide useFactory, useClass or useExisting', + ); } if (options.useFactory) { @@ -87,10 +124,18 @@ export class RedisCoreModule { return { provide: getRedisOptionsToken(connection), - async useFactory(optionsFactory: RedisModuleOptionsFactory): Promise { + async useFactory( + optionsFactory: RedisModuleOptionsFactory, + ): Promise { return await optionsFactory.createRedisModuleOptions(); }, inject: [options.useClass || options.useExisting], }; } + + protected static createAndTrackRedisConnection(options: RedisModuleOptions) { + const redis = createRedisConnection(options); + this.redisConnections.push(new WeakRef(redis)); + return redis; + } } diff --git a/lib/redis.module.spec.ts b/lib/redis.module.spec.ts index 30cfecd..1492f17 100644 --- a/lib/redis.module.spec.ts +++ b/lib/redis.module.spec.ts @@ -5,18 +5,21 @@ import { Test, TestingModule } from '@nestjs/testing'; import { RedisModule } from './redis.module'; import { getRedisConnectionToken } from './redis.utils'; import { InjectRedis } from './redis.decorators'; +import { setTimeout } from 'timers/promises'; describe('RedisModule', () => { it('Instance Redis', async () => { const module: TestingModule = await Test.createTestingModule({ - imports: [RedisModule.forRoot({ - type: 'single', - options: { - host: '127.0.0.1', - port: 6379, - password: '123456', - } - })], + imports: [ + RedisModule.forRoot({ + type: 'single', + options: { + host: '127.0.0.1', + port: 6379, + password: '123456', + }, + }), + ], }).compile(); const app = module.createNestApplication(); @@ -31,20 +34,24 @@ describe('RedisModule', () => { const defaultConnection: string = 'default'; const module: TestingModule = await Test.createTestingModule({ - imports: [RedisModule.forRoot({ - type: 'single', - options: { - host: '127.0.0.1', - port: 6379, - password: '123456', - } - })], - },).compile(); + imports: [ + RedisModule.forRoot({ + type: 'single', + options: { + host: '127.0.0.1', + port: 6379, + password: '123456', + }, + }), + ], + }).compile(); const app = module.createNestApplication(); await app.init(); const redisClient = module.get(getRedisConnectionToken(defaultConnection)); - const redisClientTest = module.get(getRedisConnectionToken(defaultConnection)); + const redisClientTest = module.get( + getRedisConnectionToken(defaultConnection), + ); expect(redisClient).toBeInstanceOf(Redis); expect(redisClientTest).toBeInstanceOf(Redis); @@ -53,7 +60,6 @@ describe('RedisModule', () => { }); it('inject redis connection', async () => { - @Injectable() class TestProvider { constructor(@InjectRedis() private readonly redis: Redis) {} @@ -64,14 +70,16 @@ describe('RedisModule', () => { } const module: TestingModule = await Test.createTestingModule({ - imports: [RedisModule.forRoot({ - type: 'single', - options: { - host: '127.0.0.1', - port: 6379, - password: '123456', - } - })], + imports: [ + RedisModule.forRoot({ + type: 'single', + options: { + host: '127.0.0.1', + port: 6379, + password: '123456', + }, + }), + ], providers: [TestProvider], }).compile(); @@ -83,4 +91,49 @@ describe('RedisModule', () => { await app.close(); }); + + it('closes all redis connections on shutdown', async () => { + const module: TestingModule = await Test.createTestingModule({ + imports: [ + RedisModule.forRoot({ + type: 'single', + options: { + host: '127.0.0.1', + port: 6379, + password: '123456', + }, + }), + RedisModule.forRoot( + { + type: 'single', + options: { + host: '127.0.0.1', + port: 6379, + password: '123456', + }, + }, + 'second', + ), + ], + }).compile(); + + const app = module.createNestApplication(); + await app.init(); + const defaultRedisClient = module.get(getRedisConnectionToken()); + const secondRedisClient = module.get( + getRedisConnectionToken('second'), + ); + + await setTimeout(1000); + + expect(defaultRedisClient.status).toBe('ready'); + expect(secondRedisClient.status).toBe('ready'); + + await app.close(); + + await setTimeout(1000); + + expect(defaultRedisClient.status).toBe('end'); + expect(secondRedisClient.status).toBe('end'); + }); }); diff --git a/lib/redis.utils.ts b/lib/redis.utils.ts index 16ef56e..86204b0 100644 --- a/lib/redis.utils.ts +++ b/lib/redis.utils.ts @@ -1,17 +1,21 @@ -import Redis, { RedisOptions } from 'ioredis'; +import Redis, { Cluster, RedisOptions } from 'ioredis'; import { RedisModuleOptions } from './redis.interfaces'; import { REDIS_MODULE_CONNECTION, REDIS_MODULE_CONNECTION_TOKEN, - REDIS_MODULE_OPTIONS_TOKEN + REDIS_MODULE_OPTIONS_TOKEN, } from './redis.constants'; export function getRedisOptionsToken(connection?: string): string { - return `${ connection || REDIS_MODULE_CONNECTION }_${ REDIS_MODULE_OPTIONS_TOKEN }`; + return `${ + connection || REDIS_MODULE_CONNECTION + }_${REDIS_MODULE_OPTIONS_TOKEN}`; } export function getRedisConnectionToken(connection?: string): string { - return `${ connection || REDIS_MODULE_CONNECTION }_${ REDIS_MODULE_CONNECTION_TOKEN }`; + return `${ + connection || REDIS_MODULE_CONNECTION + }_${REDIS_MODULE_CONNECTION_TOKEN}`; } export function createRedisConnection(options: RedisModuleOptions) { @@ -24,10 +28,23 @@ export function createRedisConnection(options: RedisModuleOptions) { const { url, options: { port, host } = {} } = options; const connectionOptions: RedisOptions = { ...commonOptions, port, host }; - return url ? new Redis(url, connectionOptions) : new Redis(connectionOptions); + return url + ? new Redis(url, connectionOptions) + : new Redis(connectionOptions); default: throw new Error('Invalid configuration'); } } - +export const tryCloseRedisConnectionPermanently = async ( + redis: Redis | Cluster, +) => { + try { + await redis.quit(); + } catch (error) { + if (error instanceof Error && error.message === 'Connection is closed.') { + return; + } + throw error; + } +}; diff --git a/tsconfig.json b/tsconfig.json index f73690d..af12204 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -13,17 +13,8 @@ "sourceMap": false, "outDir": "./dist", "rootDir": "./lib", - "lib": ["es7"] + "lib": ["es7", "ES2021.WeakRef"] }, - "include": [ - "lib/**/*" - ], - "exclude": [ - "node_modules" - ] + "include": ["lib/**/*"], + "exclude": ["node_modules"] } - - - - - \ No newline at end of file