之前我们通过讨论ReentrantLock学习到了AQS的核心、公平与非公平锁的实现以及Condition的实现原理。但是之前所涉及到的都是非共享锁,也就是独占锁。今天我们来讨论基于AQS的共享模式实现的CountDownLatch组件。
本文大体上会分为两部分进行讨论。第一部分为介绍CountDownLatch的使用,第二部分将通过源码来分析CountDownLatch的实现原理。

1. CountDownLatch的使用

CountDownLatch是一个使用频率非常高的类, 是AQS共享模式的典型应用。它的名字翻译过来为:倒计时门闩。具体的使用如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

class Driver2 { // ...
void main() throws InterruptedException {
CountDownLatch doneSignal = new CountDownLatch(N);
Executor e = Executors.newFixedThreadPool(8);

for (int i = 0; i < N; ++i) // create and start threads
e.execute(new WorkerRunnable(doneSignal, i));

doneSignal.await(); // wait for all to finish
}
}

class WorkerRunnable implements Runnable {
private final CountDownLatch doneSignal;
private final int i;

WorkerRunnable(CountDownLatch doneSignal, int i) {
this.doneSignal = doneSignal;
this.i = i;
}

public void run() {
try {
doWork(i);
// 这个线程的任务完成了,调用 countDown 方法
doneSignal.countDown();
} catch (InterruptedException ex) {
} // return;
}

void doWork() { ...}
}

以上代码是CountDownLatch源码中JavaDoc中的示例代码。不难理解,代码中的逻辑为创建了一个线程池用于执行任务。主线程等待,直到N任务全部都执行,再对主线程放行。

因此,我们可以得知它的使用场景可以是讲一个任务拆分为多个任务,让多个线程来并行执行,直到所有任务完成后,再向下执行。但是这个例子并没有完全展示出CountDownLatch的特性。接下来再看一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Driver { // ...
void main() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(N);

for (int i = 0; i < N; ++i) // create and start threads
new Thread(new Worker(startSignal, doneSignal)).start();

// 这边插入一些代码,确保上面的每个线程先启动起来,才执行下面的代码。
doSomethingElse(); // don't let run yet

// 因为这里 N == 1,所以,只要调用一次,那么所有的 await 方法都可以通过
startSignal.countDown(); // let all threads proceed

doSomethingElse();
// 等待所有任务结束
doneSignal.await(); // wait for all to finish
}
}

class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;

Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}

public void run() {
try {
// 为了让所有线程同时开始任务,我们让所有线程先阻塞在这里
// 等大家都准备好了,再打开这个门栓
startSignal.await();
doWork();
doneSignal.countDown();
} catch (InterruptedException ex) {
} // return;
}

void doWork() { ...}
}

以上代码中,整体逻辑是先等所有的线程启动后,再开始执行任务,然后等所有任务都执行完了,main线程再继续向下执行。可以理解为,有N个线程被一个栅栏阻塞住,只有当通过条件达到了,再打开栅栏放行。注意,放行后,N个线程都被放行了。

有点类似于短跑比赛,首先等所有人准备好了再开始跑,等所有人跑完全程了才能结束比赛。

2. 源码分析

2.1 整体结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
public class CountDownLatch {

private final Sync sync;

public CountDownLatch(int count) {
// count必须大于0
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

public void countDown() {
sync.releaseShared(1);
}

public long getCount() {
return sync.getCount();
}

public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}

private static final class Sync extends AbstractQueuedSynchronizer{
// ...
}
}

可见CountDownLatch的源码并不多,包括Sync的源码满打满算也就300来行。不得不感叹,Doug Lea的设计能力,把设计模式的精髓发挥到极致,抽象出的AQS能简单快速的实现一个同步组件。

2.2 await()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
public void await() throws InterruptedException {  
sync.acquireSharedInterruptibly(1);
}

public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 可中断方法,先判断中断状态,如果被中断了直接抛出中断异常即可。
if (Thread.interrupted())
throw new InterruptedException();

// 先判断是否达到放行条件。
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}

// 只要 state 不等于 0,那么这个方法返回 -1
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

// 整体流程和之前ReentrantLock时已经了解的差不多了
private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {

// 先往同步队列中如一个当前线程的节点
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
// 如果其前驱节点为head,则开始尝试获取锁
if (p == head) {
// 判断state的值
int r = tryAcquireShared(arg);
// 只有r==0时满足条件
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}

// 判断是否应该挂起,顺便给node找一个未取消的前驱节点,并确保其ws==SIGNAL
if (shouldParkAfterFailedAcquire(p, node) &&
// 将线程挂起
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

2.3 countDown()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
public void countDown() {  
sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
// state为0时返回true
if (tryReleaseShared(arg)) {
// 执行放行的操作
doReleaseShared();
return true;
}
return false;
}

// state减1,并返回是否需要放行。只有state为0时才会返回true
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
// 如果head的ws为SIGNAL
if (ws == Node.SIGNAL) {
// 如果对h的wd进行CAS操作失败则进行下一轮循环
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
// 唤醒后继节点
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

// 唤醒后继节点
private void unparkSuccessor(Node node) {
int ws = node.waitStatus;
if (ws < 0)
compareAndSetWaitStatus(node, ws, 0);

Node s = node.next;
if (s == null || s.waitStatus > 0) {
s = null;
for (Node t = tail; t != null && t != node; t = t.prev)
if (t.waitStatus <= 0)
s = t;
}
if (s != null)
LockSupport.unpark(s.thread);
}

根据以上代码中,unparkSuccessor(h)方法唤醒了后继节点的线程。现在再返回去看代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {  
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
// 当state为0时进入。
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}

if (shouldParkAfterFailedAcquire(p, node) &&
// 之前是在这里将线程挂起的。唤醒后进入下一轮循环
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}


private void setHeadAndPropagate(Node node, int propagate) {
Node h = head; // Record old head for check below
setHead(node);

// 唤醒所有的后继节点
if (propagate > 0 || h == null || h.waitStatus < 0 ||
(h = head) == null || h.waitStatus < 0) {
Node s = node.next;
if (s == null || s.isShared())
doReleaseShared();
}
}

此时再看唤醒部分的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
private void doReleaseShared() {  
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}

// 如果到这里的时候,前面唤醒的线程已经占领了 head,那么再循环
// 否则,就是 head 没变,那么退出循环,
// 退出循环是不是意味着阻塞队列中的其他节点就不唤醒了?当然不是,唤醒的线程之后还是会调用这个方法的
if (h == head) // loop if head changed
break;
}
}

关于最后的if (h == head)语句的理解:

  • h == head时:说明头节点还没有被刚刚用 unparkSuccessor 唤醒的线程占有,此时 break 退出循环。
  • h != head时:头节点被刚刚唤醒的线程占有,那么这里重新进入下一轮循环,唤醒下一个节点。那么有一个问题,刚才被唤醒的节点会主动唤醒它后面的节点,为什么这里还要再下一轮中循环呢?我觉得这里应该是处于吞吐量的考虑(帮忙唤醒)。