`

试用CountDownLatch的副产品:多线程群发邮件小程序

 
阅读更多
简要说明
1、基于javamail所做的批量发邮件小程序,SingleSender是普通发送,ThreadSender是多线程发送,MysqlSender是从数据库获取数据再调用ThreadSender来发送。
2、为什么写这个小程序?没什么特别的原因:最近想起《Thinking in Java》里介绍的CountDownLatch类,为了练习多线程相关的技术而写了这个小程序。对我而言,整个程序的核心是ThreadSender里用到的CountDownLatch。
3、jar依赖情况
1) 必须:activation.jar, mail.jar, commons-logging.jar, log4j-xx.jar, zTool-xx.jar(我写的另外一些小工具类,源码在 https://z-tool.googlecode.com/svn/trunk/zTool)
2) 可能需要:mysql-connector-java-xx.jar,velocity-xx-dep.jar
(注:上面版本号用xx代替了,依据实际情况找个合适的版本即可)

下面贴出SingleSender和ThreadSender的源码,完整源码见 https://github.com/auzll/zBatchSender 。注意:SingleSender中保留了一段模拟发信的测试代码(标有TODO),按原样运行时只会随机等待并打印日志,并不会真正发送邮件。


public final class SingleSender extends AbstractSender {
	private static final Log LOG = LogFactory.getLog(SingleSender.class);

	/**
	 * TODO test switch. While {@code true}, {@link #send(List)} only simulates delivery
	 * (a random 0-2 second pause per entity plus a "success" log line) and never opens an
	 * SMTP connection; the entities' success flags are left untouched. Set to
	 * {@code false} to send real mail.
	 */
	private static final boolean SIMULATE_ONLY = true;

	/**
	 * Sends a batch of mails, one message per entity.
	 *
	 * <p>In real (non-simulated) mode one SMTP transport is reused for up to
	 * {@code transportUsingTimes} messages before it is closed and reopened; a value
	 * {@code <= 0} disables the periodic reconnect. Each entity's success flag is set to
	 * {@code true} or {@code false}; individual send failures are logged, not thrown.
	 *
	 * @param entities the mails to send; may be {@code null} or empty
	 * @return the same {@code entities} reference that was passed in
	 */
	public final List<MailEntity> send(List<MailEntity> entities) {
		int entitiesSize = null != entities ? entities.size() : 0;
		if (LOG.isInfoEnabled()) {
			LOG.info("method:send,entitiesSize:" + entitiesSize
					+ ",senderParam:[" + this.toSimpleLog() + "]"
					+ ",thread:" + Thread.currentThread().toString());
		}

		if (entitiesSize < 1) {
			return entities;
		}

		if (SIMULATE_ONLY) {
			return simulateSend(entities);
		}
		return smtpSend(entities);
	}

	/**
	 * Sends a single mail.
	 *
	 * @param entity the mail to send
	 * @return the same {@code entity}, with its success flag updated by {@link #send(List)}
	 */
	public final MailEntity send(MailEntity entity) {
		List<MailEntity> entities = new ArrayList<MailEntity>();
		entities.add(entity);
		send(entities);
		return entity;
	}

	/**
	 * TODO test code: pretends to send each entity by sleeping 0-2 seconds and logging
	 * success. Performs no SMTP traffic and does not change the entities' success flags.
	 */
	private List<MailEntity> simulateSend(List<MailEntity> entities) {
		// One generator for the whole batch instead of a new Random per entity.
		Random random = new Random();
		for (MailEntity entity : entities) {
			try {
				TimeUnit.SECONDS.sleep(random.nextInt(3));
			} catch (InterruptedException e) {
				// Keep the interrupt visible to the caller (e.g. ExecutorService.shutdownNow)
				// instead of silently discarding it.
				Thread.currentThread().interrupt();
			}

			if (LOG.isInfoEnabled()) {
				LOG.info("method:send,result:success,entity:["
						+ entity.toSimpleLog() + "],thread:"
						+ Thread.currentThread().toString());
			}
		}
		return entities;
	}

	/**
	 * Real delivery over SMTP. Per-entity failures mark that entity unsuccessful and the
	 * loop continues; the transport is always closed when the batch ends.
	 */
	private List<MailEntity> smtpSend(List<MailEntity> entities) {
		Properties sessionProperties = newProperties();
		Transport transport = null;
		Session session = null;
		int curTimes = 0;
		try {
			for (MailEntity entity : entities) {
				try {
					// ">=" rather than "==": if an earlier iteration failed before curTimes was
					// reset, an exact match could be skipped and the transport never recycled.
					boolean recycle = transportUsingTimes > 0 && curTimes >= transportUsingTimes;
					if (null == transport || !transport.isConnected() || recycle) {
						// Best-effort close so a failing close does not fail this entity.
						closeQuietly(transport);

						session = Session.getInstance(sessionProperties);
						transport = session.getTransport(new URLName("smtp",
								smtpHost, smtpPort, null,
								from.getAddress(), password)); // obtain a new transport
						curTimes = 0;
						transport.connect();

						if (LOG.isDebugEnabled()) {
							LOG.debug("method:send,desc:transport connect,thread:"
									+ Thread.currentThread().toString());
						}
					}

					MimeMessage message = new MimeMessage(session);

					if (null != from) {
						message.setFrom(from);
					}

					if (null != entity.getTo()) {
						message.addRecipient(RecipientType.TO, entity.getTo());
					}

					if (null != entity.getCcTo()) {
						message.addRecipient(RecipientType.CC, entity.getCcTo());
					}

					if (null != entity.getBccTo()) {
						message.addRecipient(RecipientType.BCC, entity.getBccTo());
					}

					if (null != subject) {
						message.setSubject(subject);
					}

					// Per-entity content overrides the sender-wide default content.
					if (null != entity.getContent()) {
						message.setText(entity.getContent(), charset);
					} else if (null != content) {
						message.setText(content, charset);
					}

					transport.sendMessage(message, message.getAllRecipients());
					entity.setSuccess(true);

					if (LOG.isInfoEnabled()) {
						LOG.info("method:send,result:success,entity:["
								+ entity.toSimpleLog() + "],thread:"
								+ Thread.currentThread().toString());
					}

				} catch (Exception e) {
					entity.setSuccess(false);

					// Full stack trace only at debug level; otherwise just the message.
					if (LOG.isDebugEnabled()) {
						LOG.debug("method:send,result:fail,entity:[" + entity.toSimpleLog()
								+ "],thread:" + Thread.currentThread().toString(), e);
					} else {
						LOG.info("method:send,result:fail,entity:[" + entity.toSimpleLog()
								+ "],thread:" + Thread.currentThread().toString()
								+ ",e:" + e.getMessage());
					}

				} finally {
					curTimes++;
				}
			}
		} finally {
			closeQuietly(transport);
		}
		return entities;
	}

	/** Closes the transport if it is open, logging (not throwing) any failure. */
	private static void closeQuietly(Transport transport) {
		if (null == transport || !transport.isConnected()) {
			return;
		}
		try {
			transport.close();
		} catch (Exception e) {
			LOG.info("method:send,desc:close transport,thread:"
					+ Thread.currentThread().toString() , e);
		}
	}

	/**
	 * Builds the JavaMail session properties from the sender address.
	 *
	 * <p>Values are stored as Strings via {@code setProperty}: {@code Properties.getProperty}
	 * returns {@code null} for non-String values, so JavaMail code that reads these through
	 * {@code Session.getProperty} would otherwise ignore them.
	 */
	private Properties newProperties() {
		String address = from.getAddress();
		String host = address;

		// Use the domain part of the sender address as the local host name, if present.
		int atIndex = address.indexOf('@');
		if (-1 != atIndex) {
			host = address.substring(atIndex + 1);
		}

		Properties props = new Properties();
		props.setProperty("mail.smtp.localhost", host);
		props.setProperty("mail.from", address);

		props.setProperty("mail.debug", String.valueOf(mailDebug));

		if (null != password) {
			props.setProperty("mail.smtp.auth", "true");
		}

		return props;
	}
}


public final class ThreadSender extends AbstractSender {
	private static final Log LOG = LogFactory.getLog(ThreadSender.class);

	/**
	 * Sends one slice of the batch through a {@link SingleSender} configured from this
	 * sender's settings, then counts the shared latch down.
	 */
	private class Worker implements Runnable {
		private final CountDownLatch latch;
		private final List<MailEntity> entities;

		public Worker(CountDownLatch latch, List<MailEntity> entities) {
			this.latch = latch;
			this.entities = entities;
		}

		public void run() {
			try {
				// NOTE(review): mailDebug (read by SingleSender when building its session
				// properties) is not copied here -- confirm whether it should be forwarded.
				new SingleSender()
					.charset(charset)
					.smtpHost(smtpHost)
					.smtpPort(smtpPort)
					.transportUsingTimes(transportUsingTimes)
					.from(from)
					.password(password)
					.subject(subject)
					.content(content)
					.send(entities);
			} finally {
				if (LOG.isDebugEnabled()) {
					LOG.debug("method:Worker$run,desc:latch count down,thread:"
							+ Thread.currentThread().toString());
				}
				// Count down even if send() throws; otherwise ThreadSender.send() would
				// block in latch.await() forever.
				latch.countDown();
			}
		}

	}

	/** 默认的最大线程数量 (default thread count): Runtime.getRuntime().availableProcessors() * 2 */
	public static final int DEFAULT_THREAD_SIZE = Runtime.getRuntime().availableProcessors() * 2;

	/** Default maximum number of mails handled by one task. */
	public static final int DEFAULT_TASK_OF_EACH_THREAD = 10;

	/**
	 * Thread count. Only used to compute the maximum batch size
	 * ({@code threadSize * taskOfEachThread}); it does not itself bound the executor's pool.
	 */
	private int threadSize = DEFAULT_THREAD_SIZE;

	/** Maximum number of mails handled by one task. */
	private int taskOfEachThread = DEFAULT_TASK_OF_EACH_THREAD;

	private ExecutorService executorService;

	/**
	 * Whether this sender owns {@link #executorService} and shuts it down after each send.
	 * True for the no-arg constructor, false when the executor is supplied by the caller.
	 */
	private boolean shutdownExecutor = false;

	/** Creates a sender that owns an internal cached thread pool. */
	public ThreadSender() {
		this(Executors.newCachedThreadPool());
		this.shutdownExecutor = true;
	}

	/** Creates a sender that submits to a caller-managed executor (never shut down here). */
	public ThreadSender(ExecutorService executorService) {
		this.executorService = executorService;
	}

	/**
	 * Sends a batch of mails in parallel.
	 *
	 * <p>The list is cut into consecutive slices of at most {@code taskOfEachThread}
	 * entities; each slice is sent by one task on the executor, and this method blocks
	 * until every task has finished. Not safe for concurrent calls on the same instance
	 * when the executor is owned (no-arg constructor), since one call may shut the pool
	 * down while another is still submitting.
	 *
	 * @param entities the mails to send; may be {@code null} or empty
	 * @return the same {@code entities} reference, as updated by the worker tasks
	 * @throws BatchSendException if the batch is larger than
	 *         {@code threadSize * taskOfEachThread}, or if the waiting thread is interrupted
	 *         (the interrupt flag is restored first)
	 */
	public final List<MailEntity> send(List<MailEntity> entities) {
		int entitiesSize = null != entities ? entities.size() : 0;
		if (LOG.isInfoEnabled()) {
			LOG.info("method:send,entitiesSize:" + entitiesSize
					+ ",thread:" + Thread.currentThread().toString());
		}

		if (entitiesSize < 1) {
			return entities;
		}

		// long arithmetic: threadSize * taskOfEachThread can overflow an int.
		long maxTask = (long) threadSize * taskOfEachThread;
		if (entitiesSize > maxTask) {
			throw new BatchSendException("Too much task, max is " + maxTask
					+ ", current is " + entitiesSize);
		}

		int len = entitiesSize;
		int count = len / taskOfEachThread;
		if (len % taskOfEachThread > 0) {
			count++; // ceiling division: a partial last slice still needs a task
		}
		CountDownLatch latch = new CountDownLatch(count);
		ExecutorService executor = activeExecutor();
		try {
			int realCount = 0;
			int i = 0;
			while (i < len) {
				if (LOG.isDebugEnabled()) {
					LOG.debug("method:send,desc:split task,taskIndex:" + realCount);
				}
				// Written so that i + taskOfEachThread cannot overflow.
				int toIndex = len - i > taskOfEachThread ? i + taskOfEachThread : len;
				executor.execute(new Worker(latch, entities.subList(i, toIndex)));
				i = toIndex;

				realCount++;
			}

			if (count != realCount) {
				throw new BatchSendException("Unexpected error[count != realCount], count is "
						+ count + ", realCount is " + realCount);
			}

			try {
				if (LOG.isDebugEnabled()) {
					LOG.debug("method:send,desc:begin await,thread:"
							+ Thread.currentThread().toString());
				}
				latch.await();
				if (LOG.isDebugEnabled()) {
					LOG.debug("method:send,desc:finish await,thread:"
							+ Thread.currentThread().toString());
				}
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				throw new BatchSendException(e);
			}
		} finally {
			// Release an owned pool on every exit path, not only after a successful await.
			if (shutdownExecutor) {
				if (LOG.isDebugEnabled()) {
					LOG.debug("method:send,desc:try shutdown executorService,thread:"
							+ Thread.currentThread().toString());
				}
				executor.shutdownNow();
			}
		}

		return entities;
	}

	/**
	 * Returns the executor to submit to. An owned pool is shut down at the end of every
	 * {@link #send(List)}, so it is replaced by a fresh cached pool here; without this a
	 * second {@code send} would be rejected by the terminated pool.
	 */
	private ExecutorService activeExecutor() {
		if (shutdownExecutor && executorService.isShutdown()) {
			executorService = Executors.newCachedThreadPool();
		}
		return executorService;
	}

	/**
	 * Sets the maximum number of mails handled by one task.
	 *
	 * @param taskOfEachThread must be at least 1
	 * @return this sender, for chaining
	 * @throws IllegalArgumentException if {@code taskOfEachThread < 1}
	 */
	public ThreadSender taskOfEachThread(int taskOfEachThread) {
		if (taskOfEachThread < 1) {
			throw new IllegalArgumentException("taskOfEachThread must be >= 1, got " + taskOfEachThread);
		}
		this.taskOfEachThread = taskOfEachThread;
		return this;
	}

	/**
	 * Sets the thread count used to cap the batch size.
	 *
	 * @param threadSize must be at least 1
	 * @return this sender, for chaining
	 * @throws IllegalArgumentException if {@code threadSize < 1}
	 */
	public ThreadSender threadSize(int threadSize) {
		if (threadSize < 1) {
			throw new IllegalArgumentException("threadSize must be >= 1, got " + threadSize);
		}
		this.threadSize = threadSize;
		return this;
	}
}
分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics