Kafka Streams : pollPhase()

    들어가기 전

    이 글은 개인적으로 Kafka Streams의 코드를 따라가면서 작성한 글입니다. 틀린 부분이 있을 수 있기 때문에 틀린 부분을 알려주시면 다시 한번 공부해서 수정해두겠습니다.

     

    pollPhase()의 전체적인 요약

    카프카 스트림즈는 StreamThread가 모든 작업을 처리해주고 있다. StreamThread는 기본적으로 5개의 단계로 ETL을 한다. 아래 단계에서 확인할 수 있다. 파란색으로 표시해둔 부분은 반복해서 처리하는 부분이다. 이 글에서는 5개의 단계 중 pollPhase()에 대해서 알아보고자 한다. 

    pollPhase() → initializedAndRestorePhase() → [ process() → taskManager.process() → taskManager.punctuate() → Commit() ] 

     

    전체 요약

    1. StreamThread는 pollPhase() 메서드를 호출한다. 이 메서드는 Kafka Streams에서 프로세스 할 데이터를 Consume 하는 과정이다.
    2. pollPhase()가 호출되면 StreamThread는 MainConsumer.poll()을 호출한다. MainConsumer.poll()을 호출하면, Network IO 쓰레드를 통해 브로커에서 가져온 메세지를 역직렬화 한 후 Fetch해서 레코드를 받게 된다.
    3. taskManager.updateTaskEndMetadata()를 호출한다. 이 때 레코드를 함께 넘겨주는데, 각 activeTask의 topicPartition이 가지고 있는 High Water Mark를 넘겨준 레코드의 마지막 offset으로 업데이트한다. 
    4. taskManager.addRecordsToTasks()를 호출해준다. 이 때, TopicPartition으로 필터링해서 각각의 레코드가 적절한 activeTask에게 넘어가게 된다.
    5. 4번의 단계에서 레코드는 다양한 TopicPartition으로 분리되고, 이 TopicPartition은 다양한 StreamTask가 가지고 있다. 따라서 여러 activeTask들에게서 작업이 일어나게 된다.
    6. 각각의 activeTask는 자신의 PartitionGroup을 가지고 있는데, 이 때 PartitionGroup.addRawRecords() 메서드를 호출한다.
    7. addRawRecords() 메서드를 호출하면 PartitionGroup은 PartitionQues에게서 현재 TopicPartition에 대응되는 PartitionQue를 가져온다. PartitionQue는 처리해야할 레코드가 담겨있는 Que다.
    8. PartitionQue를 반환받으면 PartitionGroup은 Que.addRawRecords()를 호출해서 PartitionQue에 전달받은 레코드를 넣어준다. 
    9. 마지막으로 PartitionGroup은 nonEmptyQueByTime.offer() 메서드에 PartitionQue를 넣어준다. nonEmptyQueByTime은 우선순위 큐이고, 이 녀석은 레코드의 시간이 가장 빠른 녀석이 가장 먼저 나오는 형태로 구현되어있다. streamTask는 nonEmptyQueByTime에서 시간 순으로 가장 빨리 처리해야할 레코드 큐를 가져온 후 처리하는 형태로 구현되어있다. 

    위와 같은 단계를 거쳐 pollPhase()는 마무리 된다. pollPhase를 단순히 요약하면 다음과 같다. 

    1. Consumer를 통해서 브로커에서 메세지를 가져온다.
    2. 가져온 메세지를 nonEmptyQueByTime에 넣어주어서 각각의 StreamTask가 사용할 수 있도록 셋팅해준다.

     

    pollPhase() 코드 따라가보기

     

    StreamThread는 pollPhase()를 호출한다. pollPhase() 메서드에서는 3가지 단계로 작업이 진행된다.

    1. pollRequest() 메서드를 호출한다. 이 메서드를 호출하면 mainConsumer.poll()을 이용해서 메세지를 가져오고, 그 결과를 records에 저장한다.
    2. records.~.ifPresent() 메서드를 호출한다. 이 메서드를 호출하면 각 StreamTask가 가지고 있는 TopicPartition의 하이워터마크를 이번에 전달받은 레코드의 마지막 offset 값으로 업데이트 해준다.
    3. taskManager.addRecordsToTask()를 호출한다. 이 메서드를 호출하면 각 StreamTask가 가지고 있는 파티션 큐에 전달받은 레코드를 적절히 넣어주는 작업을 진행한다. 

    이 때, 같이 보면 좋을 부분은 현재 Kafka Streams의 상태에 따라 pollRequests()를 할 때의 인자가 다르다는 것이다. Running 상태가 아닐 때는 Duration.zero로 바로 poll() 요청을 하도록 해서 최대한 빨리 첫번째 데이터를 받아오도록 한다. Running일 경우에는 이미 연결 되어있기 때문에 pollTime에 대해서 메세지를 가져오도록 처리된다. 

    // StreamThread.java
    private long pollPhase() {
        final ConsumerRecords<byte[], byte[]> records;
        ...
    
        if (state == State.PARTITIONS_ASSIGNED) {
            // 메세지 가져오기
            records = pollRequests(Duration.ZERO);
        } else if (state == State.PARTITIONS_REVOKED) {
        	// 메세지 가져오기
            records = pollRequests(Duration.ZERO);
        } else if (state == State.RUNNING || state == State.STARTING) {
        	// 메세지 가져오기
            records = pollRequests(pollTime);
        } else if (state == State.PENDING_SHUTDOWN) {
            // 메세지 가져오기
            records = pollRequests(Duration.ZERO);
        } else {
            ...
            throw new StreamsException(logPrefix + "Unexpected state " + state + " during normal iteration");
        }
    
        ...
    
        final int numRecords = records.count();
    
        for (final TopicPartition topicPartition: records.partitions()) {
            // 하이워터마크 업데이트
            records
                .records(topicPartition)
                .stream()
                .max(Comparator.comparing(ConsumerRecord::offset))
                .ifPresent(t -> taskManager.updateTaskEndMetadata(topicPartition, t.offset()));
        }
    
        ...
    
        if (!records.isEmpty()) {
            pollRecordsSensor.record(numRecords, now);
            // 파티션 큐에 메세지 넣기
            taskManager.addRecordsToTasks(records);
        }
    
        ...
    }

    StreamThread의 pollRequests()가 호출된다. 이곳에서는 mainConsumer를 이용해서 브로커에게서 메세지를 읽어오고, 읽어온 메세지를 반환하는 작업을 한다. 

    // StreamThread.java
    private ConsumerRecords<byte[], byte[]> pollRequests(final Duration pollTime) {
        ConsumerRecords<byte[], byte[]> records = ConsumerRecords.empty();
    
        lastPollMs = now;
    
        try {
            records = mainConsumer.poll(pollTime);
        } catch (final InvalidOffsetException e) {
            resetOffsets(e.partitions(), e);
        }
    
        return records;
    }

    TaskManager.updateTaskEndMetadata()가 호출된다. 이곳에서는 각 StreamTask에게 for문을 돌면서 updateEndOffsets()를 호출하고 이 때 topicPartition, offset을 인자로 넘겨준다. 이 때 각 StreamTask가 가지고 있는 각 topicPartition의 하이워터마크가 offset으로 저장되게 된다. 

    // TaskManager.java
    public void updateTaskEndMetadata(final TopicPartition topicPartition, final Long offset) {
        for (final Task task : tasks.activeTasks()) {
            if (task instanceof StreamTask) {
                if (task.inputPartitions().contains(topicPartition)) {
                    ((StreamTask) task).updateEndOffsets(topicPartition, offset);
                }
            }
        }
    }

    TaskManager.addRecordsToTasks() 메서드를 호출한다. 이 메서드에서는 레코드의 각 토픽 파티션에 맞는 StreamTask를 찾아온다. 찾아온 StreamTask의 addRecords() 메서드를 호출하는데, 이 때 각 토픽 파티션에 맞는 레코드를 넘겨준다. 이 메서드를 호출하게 되면 브로커에서 읽어온 레코드가 Processor에서 처리할 수 있도록 저장되게 된다. 

    // TaskManager.java
    void addRecordsToTasks(final ConsumerRecords<byte[], byte[]> records) {
        for (final TopicPartition partition : records.partitions()) {
            final Task activeTask = tasks.activeTasksForInputPartition(partition);
    
            if (activeTask == null) {
                log.error("Unable to locate active task for received-record partition {}. Current tasks: {}",
                    partition, toString(">"));
                throw new NullPointerException("Task was unexpectedly missing for partition " + partition);
            }
    
            activeTask.addRecords(partition, records.records(partition));
        }
    }

    StreamTask.addRecords()로 넘어오게 된다. 이 곳에서는 크게 두 가지 작업을 한다.

    1. partitionGroup.addrawRecords()를 호출한다. 이 메서드를 호출하면 Processor가 작업할 수 있도록 읽어온 레코드를 셋팅해주는 작업을 한다.
    2. 만약 현재 작업대에 올라간 레코드의 사이즈(현재 ActiveTask가 가진 maxBufferedSize보다 큰 경우, mainConsumer가 poll() 해오는 것을 잠시 멈추도록 pause() 메서드를 호출해준다. 
    // StreamTask.java
    public void addRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> records) {
        final int newQueueSize = partitionGroup.addRawRecords(partition, records);
    
        ...
    
        if (newQueueSize > maxBufferedSize) {
            mainConsumer.pause(singleton(partition));
        }
    }

    PartitionGroup.addRawRecords() 메서드가 호출된다. 이 메서드는 세 가지 단계로 작업을 진행한다.

    1. PartitionQues(PartitionQue의 Collection) 객체에게서 해당 partition에 맞는 레코드 큐를 받아온다. 
    2. RecordQue.addRawRecords() 메서드를 호출해서 받아온 메세지를 레코드 큐에 넣어준다. 
    3. nonEmptyQueuesByTime.offer() 메서드를 호출해서 레코드 큐를 넣어준다.

    레코드 큐는 내부적으로 FifoQue라는 Que를 가지고 있는데, 이곳에 받아온 메세지를 저장할 수 있다. 이곳에 받아온 메세지를 저장한 후에 이 레코드 큐 자체를 nonEmptyQueByTime이라는 곳에 넣어준다. nonEmptyQueByTime은 우선순위 큐이며, 레코드 큐들중 가장 빠른 시간의 레코드가 존재하는 큐가 가장 먼저 나오는 형태로 구현되어있다. 각 StreamTask는 Processor가 작업을 진행할 때 nonEmptyQueByTime에서 레코드 큐를 뽑아서, 레코드 큐의 레코드를 하나 처리하고 다시 우선순위 큐에 넣는 방식으로 작업이 진행된다. 

    // partition Group
    int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
        final RecordQueue recordQueue = partitionQueues.get(partition);
    
        ...
    
        final int oldSize = recordQueue.size();
        final int newSize = recordQueue.addRawRecords(rawRecords);
    
        
        if (oldSize == 0 && newSize > 0) {
            nonEmptyQueuesByTime.offer(recordQueue);
    
            if (nonEmptyQueuesByTime.size() == this.partitionQueues.size()) {
                allBuffered = true;
            }
        }
    
        totalBuffered += newSize - oldSize;
        return newSize;
    }

     

    따라가면서 알았던 부분

    1. 각 StreamTask는 PartitionGroup, RecordQue, NonEmptyQuesByTime을 가지고 있어서, 다중 StreamThread의 병렬처리에서 안전할 수 있다.
    2. maxBufferSize는 각 StreamTask별로 결정된다. maxBufferSize는 각 StreamTask가 가지고 있는 레코드 큐의 전체 총합을 제한하는 인자다. 

     

     

     

    댓글

    Designed by JB FACTORY