본문 바로가기

실전코드/Spring, Java

[MyBatis] MyBatis Interceptor를 활용한 페이징(페이지네이션) 처리 : Pageable 객체를 파라미터로

728x90
반응형

개요

기존 MyBatis에서 페이징 처리를 위해선 페이징이 필요한 조회구문에 Top n Query를 직접 작성하거나, <sql>을 사용해서 미리 작성된 구문을 연결해서 붙여줘야 했다.

하지만 이런 방법은 여러 개발자가 수많은 조회구문에 작성하기에는 코드 중복이 심하고 가독성이 떨어져 유지보수가 힘들어질 수 있다는 단점이 있다.

이러한 단점을 해결하기 위해 MyBatis Interceptor를 사용해서 특정 파라미터를 전달받은 select 쿼리에 대해 페이징 처리를 할 수 있는데 이를 알아보고자 한다.

전달받을 파라미터로는 개발자 임의로 정할 수 있으나, JPA에서 사용하는 java 표준인 Pageable 객체를 사용했다.

접근방법

  • MyBatis Interceptor 클래스를 생성한 후 intercept를 Override해준다.
  • intercept의 매개변수인 invocation에서 Pageable을 파라미터로 전달받았는지 검사한다.
  • Pageable을 파라미터로 전달받은 경우 해당 select 구문의 count를 구하는 쿼리를 실행하고, 페이징된 결과값을 반환한다.

코드작성

Interceptor 생성

@Intercepts({
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
        )
})
public class MybatisPaginationInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // 기본 응답 반환
        return invocation.proceed();
    }
}

@Signature 어노테이션

  • type : 인터셉터를 적용할 MyBatis 클래스의 타입. Executor.class는 Mybatis에서 SQL 실행을 담당하는 클래스
  • method : 인터셉터가 적용될 메서드 이름
  • args : 메서드의 파라미터 타입 배열
    • MappedStatement.class : sql id, sql 구문, 파라미터 등을 담고있는 객체
    • RowBounds.class : 페이징 처리를 할 수 있는 객체

Pageable 파라미터 여부 확인

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        // args[1] : 파라미터(두개 이상일 경우 리스트, 한개일 경우 해당객체)

        Pageable pageable = null;
        
        if(args[1] instanceof Map<?, ?> argMap
                && argMap.containsKey("pageable")
                && argMap.get("pageable") instanceof Pageable) {
            // 여러 파라미터 중 Pageable pageable이 포함됐을 경우
            pageable = (Pageable) argMap.get("pageable");
        } else if (args[1] instanceof Pageable) {
            // 단일 파라미터로 Pageable pageable이 전달됐을 경우
            pageable = (Pageable) args[1];
        }
        
        if (pageable != null) {
        	// 페이징 처리
        }

        // 기본 응답 반환
        return invocation.proceed();
    }

count 쿼리 실행

        // 카운트 쿼리 실행 전에 기존 쿼리정보 반환
        MappedStatement mappedStatement = (MappedStatement) args[0];
        
        if (pageable != null) {
            // count 쿼리 실행
            List<ResultMap> countResultMaps = new ArrayList<>();
            ResultMap countResultMap = new ResultMap.Builder(
                    mappedStatement.getConfiguration(),
                    mappedStatement.getId()+ "Count",
                    Integer.class, new ArrayList<>()
            ).build();
            countResultMaps.add(countResultMap);
            
            // Top n Query 추가
            BoundSql boundSql = mappedStatement.getBoundSql(args[1]);
            String mappedStatementSql= boundSql.getSql();
            StringBuilder sb = new StringBuilder("SELECT COUNT(1) FROM ( ");
            sb.append(mappedStatementSql);
            sb.append(" )  COUNT_TBL");

            BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(),
                    sb.toString(),
                    boundSql.getParameterMappings(),
                    boundSql.getParameterObject()
            );
            
            // <foreach> 구문에서 list로 받은 파라미터를 사용하기 위한 매핑 작업
            if (boundSql.getAdditionalParameters() != null) {
                boundSql.getAdditionalParameters().forEach(countBoundSql::setAdditionalParameter);
            }

            invocation.getArgs()[0] = new MappedStatement
                    .Builder(
                        mappedStatement.getConfiguration()
                        , mappedStatement.getId()+ "Count"
                        , param->countBoundSql
                        , mappedStatement.getSqlCommandType()
                    )
                    .resource(mappedStatement.getResource())
                    .parameterMap(mappedStatement.getParameterMap())
                    .resultMaps(countResultMaps)
                    .fetchSize(mappedStatement.getFetchSize())
                    .timeout(mappedStatement.getTimeout())
                    .statementType(mappedStatement.getStatementType())
                    .resultSetType(mappedStatement.getResultSetType())
                    .cache(mappedStatement.getCache())
                    .useCache(true)
                    .flushCacheRequired(mappedStatement.isFlushCacheRequired())
                    .resultOrdered(mappedStatement.isResultOrdered())
                    .keyGenerator(mappedStatement.getKeyGenerator())
                    .keyColumn(mappedStatement.getKeyColumns() != null ? String.join(",", mappedStatement.getKeyColumns()) : null)
                    .keyProperty(mappedStatement.getKeyProperties() != null ? String.join(",", mappedStatement.getKeyProperties()): null)
                    .databaseId(mappedStatement.getDatabaseId())
                    .lang(mappedStatement.getLang())
                    .resultSets(mappedStatement.getResultSets() != null ? String.join(",", mappedStatement.getResultSets()): null)
                    .build();
            invocation.getArgs()[2] = new RowBounds();
            // 수정한 쿼리 실행
            List<Integer> totalCount = (List<Integer>) invocation.proceed();
        }

select 쿼리 및 페이징 실행

        if (pageable != null) {
            // count 쿼리 실행
            // -----
            
            // select 쿼리 실행
            invocation.getArgs()[0] = mappedStatement;
            // 페이징 정보 담은 RowBounds 반환
            invocation.getArgs()[2] = new RowBounds((int) pageable.getOffset(), pageable.getPageSize());
            List<Object> proceedList = (List<Object>) invocation.proceed(); 
        }

결과 반환

        if (pageable != null) {
            // count 쿼리 실행
            // -----
            
            // select 쿼리 실행
            // -----

            int totalElements = totalCount.isEmpty() ? 0 : totalCount.get(0);  
            
            List<Object> resultList = new ArrayList<>();
            // Page 객체에 결과를 담아 List로 반환
            resultList.add(new PageImpl<>(proceedList, pageable, totalElements));
            return resultList;
        }

여기서 MyBatis의 한계 중 한가지가 드러난다.

select 구문은 mapper에서 List 자료형으로만 결과가 반환되는것. 따라서 Page객체에 반환된 정보를 담기 위해 임의의 리스트에 받아서 반환했다.

따라서 기존에 Mapper에서 선언한 List<ResultType.class>의 자료형에 맞출 수 없다.

리팩터링 결과

@Intercepts({
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
        )
})
public class MybatisPaginationInterceptor implements Interceptor {
    private static final String PAGEABLE_PARAM_NAME = "pageable";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // pageable 추출
        Pageable pageable = extractPageable(invocation.getArgs());

        if (pageable != null) {
            return getPaginationResponses(invocation, pageable);
        }

        // 기본 응답 반환
        return invocation.proceed();
    }

    private Pageable extractPageable(Object[] args) {
        Pageable pageable = null;

        if(args[1] instanceof Map<?, ?> argMap
                && argMap.containsKey(PAGEABLE_PARAM_NAME)
                && argMap.get(PAGEABLE_PARAM_NAME) instanceof Pageable) {
            // 여러 파라미터 중 Pageable pageable이 포함됐을 경우
            pageable = (Pageable) argMap.get(PAGEABLE_PARAM_NAME);
        } else if (args[1] instanceof Pageable) {
            // 단일 파라미터로 Pageable pageable이 전달됐을 경우
            pageable = (Pageable) args[1];
        }
        return pageable;
    }

    private List<?> getPaginationResponses(Invocation invocation, Pageable pageable)
            throws InvocationTargetException, IllegalAccessException {
        Object[] args = invocation.getArgs();
        // args[0]의 초기 참조를 보존하여 원래 쿼리로 복구하는데 사용
        MappedStatement mappedStatement = (MappedStatement) args[0];

        // count 쿼리 실행
        int totalCount = getTotalCount(invocation, args);
        // select 쿼리 실행
        List<Object> proceedList = getProceedList(invocation, pageable, mappedStatement);

        List<Object> resultList = new ArrayList<>();
        resultList.add(new PageImpl<>(proceedList, pageable, totalCount));
        return resultList;
    }

    private int getTotalCount(Invocation invocation, Object[] args)
            throws InvocationTargetException, IllegalAccessException {
        invocation.getArgs()[0] = createCountMappedStatement(args);
        invocation.getArgs()[2] = new RowBounds();
        List<Integer> totalCount = (List<Integer>) invocation.proceed();
        return totalCount.isEmpty() ? 0 : totalCount.get(0);
    }

    private static List<Object> getProceedList(Invocation invocation, Pageable pageable, MappedStatement mappedStatement)
            throws InvocationTargetException, IllegalAccessException {
        invocation.getArgs()[0] = mappedStatement;
        // 페이징 정보 담은 RowBounds 반환
        invocation.getArgs()[2] = new RowBounds((int) pageable.getOffset(), pageable.getPageSize());

        List<Object> proceedList = (List<Object>) invocation.proceed();
        return proceedList;
    }

    private MappedStatement createCountMappedStatement(Object[] args){
        MappedStatement mappedStatement = (MappedStatement) args[1];
        List<ResultMap> countResultMaps = new ArrayList<>();
        ResultMap countResultMap = new ResultMap.Builder(
                mappedStatement.getConfiguration(),
                mappedStatement.getId()+ "Count",
                Integer.class, new ArrayList<>()
        ).build();
        countResultMaps.add(countResultMap);

        BoundSql boundSql = mappedStatement.getBoundSql(args[1]);
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(),
                createCountQuery(boundSql),
                boundSql.getParameterMappings(),
                boundSql.getParameterObject()
        );

        // <foreach> 구문에서 list로 받은 파라미터를 사용하기 위한 매핑 작업
        if (boundSql.getAdditionalParameters() != null) {
            boundSql.getAdditionalParameters().forEach(countBoundSql::setAdditionalParameter);
        }

        return new MappedStatement.Builder(mappedStatement.getConfiguration(),
                mappedStatement.getId()+ "Count",
                param->countBoundSql,
                mappedStatement.getSqlCommandType())
                .resource(mappedStatement.getResource())
                .parameterMap(mappedStatement.getParameterMap())
                .resultMaps(countResultMaps)
                .fetchSize(mappedStatement.getFetchSize())
                .timeout(mappedStatement.getTimeout())
                .statementType(mappedStatement.getStatementType())
                .resultSetType(mappedStatement.getResultSetType())
                .cache(mappedStatement.getCache())
                .useCache(true)
                .flushCacheRequired(mappedStatement.isFlushCacheRequired())
                .resultOrdered(mappedStatement.isResultOrdered())
                .keyGenerator(mappedStatement.getKeyGenerator())
                .keyColumn(mappedStatement.getKeyColumns() != null ? String.join(",", mappedStatement.getKeyColumns()) : null)
                .keyProperty(mappedStatement.getKeyProperties() != null ? String.join(",", mappedStatement.getKeyProperties()): null)
                .databaseId(mappedStatement.getDatabaseId())
                .lang(mappedStatement.getLang())
                .resultSets(mappedStatement.getResultSets() != null ? String.join(",", mappedStatement.getResultSets()): null)
                .build();
    }

    private String createCountQuery(BoundSql boundSql) {
        // Top n Query 추가
        String mappedStatementSql= boundSql.getSql();
        StringBuilder sb = new StringBuilder("SELECT COUNT(1) FROM ( ");
        sb.append(mappedStatementSql);
        sb.append(" )  COUNT_TBL");

        return sb.toString();
    }
}

사용 테스트

@Mapper
public class TestMapper {
    List<TestDto> findAll();
    List<TestDto> findAll(Pageable pageable);
}
@SpringBootTest
public class MybatisInterceptorTest{
    @Autowired
    private TestMapper testMapper;
    
    @Test
    test() {
        Pageable pageable = new PageRequest.of(0, 5);
        List<Object> findAll = testMapper.findAll(pageable);
        
        assertThat(findAll)
                .isNotEmpty()
                .first()
                .isInstanceOf(Page.class);
    }
}

Over Loading에 따라 파라미터가 없으면 전체 데이터를, 파라미터가 있으면 첫번째 5개의 데이터를 조회해서 Page 객체를 List에 담아 반환하고 있다.

이 List에서 Page를 추출하는 Util 메서드나 공통 반환 객체 등을 생성해서 원하는 대로 사용하면 된다.

추가 참고사항

이 상태로 사용할 수도 있으나 발생할 수 있는 가장 불편한 점은 파라미터가 pageable 외에 하나 더 있을 때이다.

보통 파라미터를 하나만 넘기면 parameter name을 설정하지 않고 바로 매핑해서 사용하는데, 파라미터가 두개 이상일 경우 ${param1.data}, ${param2.data} 와 같은 형태로 구분해서 사용한다.

pageable을 전달하지 않을 때와 전달할 때 개발자가 이를 신경쓰지 않고 사용하기 위해선 추가적인 조작을 필요로 한다.

 

이 때 핵심은 args[1]을 리스트에서 단일 객체로 바꿔주는 것이다.

 

Mapper 클래스에서 xml 쿼리를 실행할 때 전달한 파라미터는 args[1]에 저장되어 사용되게 된다.

이 때 파라미터가 하나의 객체만 전달될 경우 args[1]은 전달된 객체 자체가 된다.

하지만 파라미터가 두개 이상 전달될 경우 args[1]은 Map으로 바뀌는데, 이 때 데이터의 개수는 (파라미터 개수 *2)이다.

만약 파라미터로 Pageable pageable, SearchDto searchDto 두개를 전달했다면, args[1]의 키는 pageable, param1, searchDto, param2가 되며, 파라미터의 name을 지정할 경우 pageable, searchDto에 해당하는 키가 지정한 name으로 생성된다.

따라서 개발자들이 Pageable 파라미터를 추가로 전달해도 이를 무시하고 단일 파라미터처럼 사용하기 위해선 args[1]의 데이터의 개수, key를 구성하는 문자열 등을 고려해서 pageable을 제외한 나머지 객체를 추출한 후 args[1]이 단일 객체를 참조하게끔 수정해 줘야 한다.

 

* 이는 개발 지침에 따라 name을 반드시 지정 및 사용해야 할 수도 있고, 이 게시물의 본래 목적과 다르기 때문에 코드는 생략합니다.

728x90
반응형