1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  package org.apache.log4j;
18  
19  import java.lang.ref.Reference;
20  import java.lang.reflect.Field;
21  import java.lang.reflect.Method;
22  import junit.framework.TestCase;
23  
24  
25  
26  
27  
28  
29  public class MDCTestCase extends TestCase {
30  
31    public void setUp() {
32      MDC.clear();
33    }
34  
35    public void tearDown() {
36      MDC.clear();
37    }
38  
39    public void testPut() throws Exception {
40      MDC.put("key", "some value");
41      assertEquals("some value", MDC.get("key"));
42      assertEquals(1, MDC.getContext().size());
43    }
44    
45    public void testRemoveLastKey() throws Exception {
46      MDC.put("key", "some value");
47  
48      MDC.remove("key");
49      checkThreadLocalsForLeaks();
50    }
51  
52    private void checkThreadLocalsForLeaks() throws Exception {
53  
54        
55  
56        
57        Field threadLocalsField = Thread.class.getDeclaredField("threadLocals");
58        threadLocalsField.setAccessible(true);
59        Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
60        inheritableThreadLocalsField.setAccessible(true);
61        
62        Class tlmClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
63        Field tableField = tlmClass.getDeclaredField("table");
64        tableField.setAccessible(true);
65  
66        Thread thread = Thread.currentThread();
67  
68        Object threadLocalMap;
69        threadLocalMap = threadLocalsField.get(thread);
70        
71        checkThreadLocalMapForLeaks(threadLocalMap, tableField);
72        
73        threadLocalMap = inheritableThreadLocalsField.get(thread);
74        checkThreadLocalMapForLeaks(threadLocalMap, tableField);
75  
76    }
77  
78    private void checkThreadLocalMapForLeaks(Object map, Field internalTableField) 
79            throws IllegalAccessException, NoSuchFieldException {
80      if (map != null) {
81        Object[] table = (Object[]) internalTableField.get(map);
82        if (table != null) {
83          for (int j =0; j < table.length; j++) {
84            if (table[j] != null) {
85  
86              
87              Object key = ((Reference) table[j]).get();
88              String keyClassName = key.getClass().getName();
89  
90              if (key.getClass() == org.apache.log4j.helpers.ThreadLocalMap.class) {
91                fail("Found a ThreadLocal with key of type [" + keyClassName + "]");
92              }
93            }
94          }
95        }
96      }
97    }
98  }