兩小時手寫springmvc框架

luffy5459發表於2018-08-15

這篇文章整理自咕泡學院Tom老師「手寫Spring框架」的教學視訊,程式碼基本沿用原文,稍作修改,可以幫助我們理解Spring MVC實現的大致原理。

1、構建maven工程,只需要匯入javax.servlet-api的依賴。另外在build中配置jetty-maven-plugin外掛,直接通過jetty外掛來啟動專案(例如執行mvn jetty:run)。

  <!-- The Servlet API is the only dependency. "provided" scope: needed to compile,
       supplied by the servlet container (Jetty, below) at run time. -->
  <dependencies>
    <dependency>
      <groupId>javax.servlet</groupId>
      <artifactId>javax.servlet-api</artifactId>
      <version>3.0.1</version>
      <scope>provided</scope>
    </dependency>
  </dependencies>
  <build>
    <!-- Also reused below as the web app context path (/springmvc). -->
    <finalName>springmvc</finalName>
    <plugins>
        <!-- Embedded Jetty for running the web app from Maven, e.g. mvn jetty:run -->
        <plugin>
              <groupId>org.eclipse.jetty</groupId>
              <artifactId>jetty-maven-plugin</artifactId>
              <version>9.4.7.v20170914</version>
              <configuration>
                  <webApp>
                      <contextPath>/${project.build.finalName}</contextPath>
                  </webApp>
                  <!-- Key and port that the jetty:stop goal uses to shut the server down. -->
                  <stopKey>CTRL+C</stopKey>
                  <stopPort>8999</stopPort>
                  <!-- Check the web app (and the extra scan targets) for changes every
                       10 seconds and redeploy when something changed. -->
                  <scanIntervalSeconds>10</scanIntervalSeconds>
                  <scanTargets>
                      <scanTarget>src/main/webapp/WEB-INF/web.xml</scanTarget>
                  </scanTargets>
              </configuration>
        </plugin>
    </plugins>
  </build>

專案結構:

2、編寫主要的程式碼並配置web.xml

XXAutowired.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;
/**
 * Marks a field that the dispatcher servlet's container fills in after all beans exist.
 * {@link #value()} names the bean to inject; when it is left empty the field's type name
 * is used as the lookup key instead.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface XXAutowired {
	/** Bean name to inject; {@code ""} means "look up by the field's type name". */
	String value() default "";
}

XXController.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a web controller. The dispatcher servlet creates one instance of it and
 * maps its public {@link XXRequestMapping}-annotated methods to URLs.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface XXController {
	/** Optional bean name. */
	String value() default "";
}

XXRequestMapping.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;

/**
 * Maps a URL to a controller. On a type it supplies the common prefix; on a method it
 * supplies the path appended to that prefix. The combined value is compiled as a regular
 * expression and matched against the context-relative request URI.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
public @interface XXRequestMapping {
	/** URL prefix (on a type) or path (on a method). */
	String value() default "";
}

XXRequestParam.java

package com.xxx.springmvc.annotation;
import java.lang.annotation.*;
/**
 * Binds a handler-method parameter to the HTTP request parameter named {@link #value()}.
 * A parameter whose annotation value is empty is not bound.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface XXRequestParam {
	/** Name of the request parameter to bind. */
	String value() default "";
}

XXService.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a service bean that the dispatcher servlet's container instantiates.
 * {@link #value()} is the bean name; when empty, the simple class name with a lower-case
 * first letter is used.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface XXService {
	/** Bean name; {@code ""} derives it from the class name. */
	String value() default "";
}

實體類User.java

package com.xxx.springmvc.entity;

/**
 * Demo user entity. {@link #toString()} renders the user as a JSON object; the demo service
 * returns that text and the demo controller writes it straight into HTTP responses.
 */
public class User {
	private String id;
	private String username;
	// NOTE(review): the password is part of toString() and therefore of the demo's JSON
	// responses; leave it out of the rendering before serving real data.
	private String password;

	public User() {}

	public User(String id, String username, String password) {
		this.id = id;
		this.username = username;
		this.password = password;
	}

	public String getId() {
		return id;
	}

	public void setId(String id) {
		this.id = id;
	}

	public String getUsername() {
		return username;
	}

	public void setUsername(String username) {
		this.username = username;
	}

	public String getPassword() {
		return password;
	}

	public void setPassword(String password) {
		this.password = password;
	}

	/**
	 * Renders this user as a JSON object. All three fields are strings, so each one is emitted
	 * as a quoted, escaped JSON string, or as {@code null} when unset. This keeps the output
	 * valid JSON for non-numeric ids and for values containing quotes or control characters.
	 */
	@Override
	public String toString() {
		return "{\"id\":" + jsonString(id) + ", \"username\":" + jsonString(username)
				+ ", \"password\":" + jsonString(password) + "}";
	}

	/** Encodes {@code value} as a JSON string literal, or the JSON literal {@code null}. */
	private static String jsonString(String value) {
		if (value == null) {
			return "null";
		}
		StringBuilder sb = new StringBuilder(value.length() + 2).append('"');
		for (int i = 0; i < value.length(); i++) {
			char c = value.charAt(i);
			switch (c) {
				case '"':
					sb.append("\\\"");
					break;
				case '\\':
					sb.append("\\\\");
					break;
				case '\n':
					sb.append("\\n");
					break;
				case '\r':
					sb.append("\\r");
					break;
				case '\t':
					sb.append("\\t");
					break;
				default:
					// JSON forbids raw control characters inside strings.
					if (c < 0x20) {
						sb.append(String.format("\\u%04x", (int) c));
					} else {
						sb.append(c);
					}
			}
		}
		return sb.append('"').toString();
	}
}

UserService.java介面檔案

package com.xxx.springmvc.web.service;

import java.util.List;

import com.xxx.springmvc.entity.User;

/**
 * Read access to the demo users.
 */
public interface UserService {

	/**
	 * Returns every known user.
	 *
	 * @return the users
	 */
	List<User> list();

	/**
	 * Looks up a user by {@code name} and returns it rendered as a string.
	 *
	 * @param name lookup key
	 * @return the rendered user; the handling of unknown names is up to the implementation
	 */
	String get(String name);
}

UserServiceImpl.java

package com.xxx.springmvc.web.service.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.xxx.springmvc.annotation.XXService;
import com.xxx.springmvc.entity.User;
import com.xxx.springmvc.web.service.UserService;
/**
 * In-memory {@link UserService} over five fixed demo users keyed "aa" to "ee", registered
 * with the container under the bean name {@code "userService"}.
 */
@XXService("userService")
public class UserServiceImpl implements UserService {

	/** Demo data keyed by lookup name; populated once by the static initializer. */
	private static final Map<String, User> users = new HashMap<String, User>();

	static {
		users.put("aa", new User("1", "aaa", "123456"));
		users.put("bb", new User("2", "bbb", "123456"));
		users.put("cc", new User("3", "ccc", "123456"));
		users.put("dd", new User("4", "ddd", "123456"));
		users.put("ee", new User("5", "eee", "123456"));
	}

	/**
	 * Returns {@code toString()} of the user stored under {@code name}; unknown names fall
	 * back to the user stored under "aa".
	 */
	@Override
	public String get(String name) {
		User match = users.get(name);
		User result = (match == null) ? users.get("aa") : match;
		return result.toString();
	}

	/** Returns a new, mutable list of all users in the map's (unspecified) iteration order. */
	@Override
	public List<User> list() {
		return new ArrayList<User>(users.values());
	}

}

UserController.java

package com.xxx.springmvc.web.controller;
import java.io.IOException;
import java.util.List;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.xxx.springmvc.annotation.XXAutowired;
import com.xxx.springmvc.annotation.XXController;
import com.xxx.springmvc.annotation.XXRequestMapping;
import com.xxx.springmvc.annotation.XXRequestParam;
import com.xxx.springmvc.entity.User;
/**
 * Demo controller mounted under {@code /user}. Each handler writes a JSON body directly to
 * the response; the returned string is only a nominal view name.
 */
@XXRequestMapping("/user")
@XXController
public class UserController {

	@XXAutowired
	private com.xxx.springmvc.web.service.UserService userService;

	/**
	 * Writes the user looked up by the {@code name} request parameter, and echoes the
	 * lookup to stdout.
	 */
	@XXRequestMapping("/index")
	public String index(HttpServletRequest request, HttpServletResponse response,
			@XXRequestParam("name") String name) throws IOException {
		final String json = userService.get(name);
		System.out.println(name + "=>" + json);
		writeJson(response, json);
		return "index";
	}

	/** Writes all users as a JSON array (rendered by {@code List.toString()}). */
	@XXRequestMapping("/list")
	public String list(HttpServletRequest request, HttpServletResponse response)
			throws IOException {
		final List<User> all = userService.list();
		writeJson(response, all.toString());
		return "list";
	}

	/** Sets the JSON content type and writes {@code body} as the response text. */
	private static void writeJson(HttpServletResponse response, String body) throws IOException {
		response.setContentType("application/json;charset=UTF-8");
		response.getWriter().write(body);
	}
}

核心類:XXDispatcherServlet.java

package com.xxx.springmvc.servlet;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.xxx.springmvc.annotation.XXAutowired;
import com.xxx.springmvc.annotation.XXController;
import com.xxx.springmvc.annotation.XXRequestMapping;
import com.xxx.springmvc.annotation.XXRequestParam;
import com.xxx.springmvc.annotation.XXService;

/**
 * Minimal front controller in the spirit of Spring MVC's DispatcherServlet.
 *
 * <p>{@link #init(ServletConfig)} loads the properties resource named by the
 * {@code contextConfigLocation} init-param and scans the {@code scanPackage} directory for
 * classes. It then creates one instance per {@link XXController}/{@link XXService} class,
 * injects {@link XXAutowired} fields, and registers a {@link Handler} for every public
 * {@link XXRequestMapping} method. Each request is routed to the first handler whose pattern
 * matches the context-relative URI. The pattern is the mapping value compiled as a regular
 * expression.
 *
 * <p>All containers are filled during {@code init} and only read afterwards. Controller
 * instances are shared by all request threads.
 */
public class XXDispatcherServlet extends HttpServlet{
	private static final long serialVersionUID = -4943120355864715254L;

	/** Properties loaded from the {@code contextConfigLocation} resource. */
	private final Properties contextConfig = new Properties();
	/** Fully-qualified names of the classes found below {@code scanPackage}. */
	private final List<String> classNames = new ArrayList<String>();
	/** Bean container: bean name, or implemented interface name, to the shared instance. */
	private final Map<String, Object> ioc = new HashMap<String, Object>();
	/** Interface names implemented by several services; these cannot be autowired by type. */
	private final Set<String> ambiguousTypes = new HashSet<String>();
	/** URL handlers in registration order; the first match wins. */
	private final List<Handler> handlerMapping = new ArrayList<Handler>();

	/**
	 * Bootstraps the bean container and the URL mappings.
	 *
	 * @throws ServletException if the config resource or scan package is missing, or a bean
	 *     cannot be created or wired
	 */
	@Override
	public void init(ServletConfig config) throws ServletException {
		// GenericServlet.init(config) stores the config; getServletConfig(),
		// getServletContext() and log() all depend on it.
		super.init(config);
		doLoadConfig(config.getInitParameter("contextConfigLocation"));
		doScanner(contextConfig.getProperty("scanPackage"));
		doInstance();
		doAutoWired();
		initHandlerMapping();
	}

	/** Loads the classpath resource {@code location} into {@link #contextConfig}. */
	private void doLoadConfig(String location) throws ServletException {
		if (location == null || location.trim().isEmpty()) {
			throw new ServletException("Missing init-param 'contextConfigLocation'");
		}
		InputStream input = getClass().getClassLoader().getResourceAsStream(location.trim());
		if (input == null) {
			throw new ServletException("Config resource not found on classpath: " + location);
		}
		try {
			contextConfig.load(input);
		} catch (IOException e) {
			throw new ServletException("Cannot read config resource: " + location, e);
		} finally {
			try {
				input.close();
			} catch (IOException e) {
				log("Cannot close config resource: " + location, e);
			}
		}
	}

	/**
	 * Recursively records the names of all {@code .class} files below {@code packageName}.
	 * Only packages stored as a directory on the classpath (such as WEB-INF/classes) can be
	 * scanned; packages inside JAR files are rejected.
	 */
	private void doScanner(String packageName) throws ServletException {
		if (packageName == null || packageName.trim().isEmpty()) {
			throw new ServletException("Missing config property 'scanPackage'");
		}
		String pkg = packageName.trim();
		// ClassLoader resource names are relative (no leading '/'); only Class.getResource
		// strips a leading slash.
		URL resource = getClass().getClassLoader().getResource(pkg.replace('.', '/'));
		if (resource == null) {
			throw new ServletException("Package not found on classpath: " + pkg);
		}
		File classDir;
		try {
			// toURI() decodes escapes such as %20 that URL.getFile() would leave in the path.
			classDir = new File(resource.toURI());
		} catch (URISyntaxException e) {
			throw new ServletException("Invalid package location: " + resource, e);
		} catch (IllegalArgumentException e) {
			// Raised for non-file URIs, e.g. a package packaged inside a JAR.
			throw new ServletException("Package is not a classpath directory: " + resource, e);
		}
		File[] children = classDir.listFiles();
		if (children == null) {
			return; // not a directory, or the directory could not be listed
		}
		for (File child : children) {
			String name = child.getName();
			if (child.isDirectory()) {
				doScanner(pkg + "." + name);
			} else if (name.endsWith(".class")) {
				// Strip only the trailing suffix; replace(".class", "") would also mangle
				// package segments such as "classes".
				classNames.add(pkg + "." + name.substring(0, name.length() - ".class".length()));
			}
		}
	}

	/**
	 * Creates one instance of every scanned {@link XXController} or {@link XXService} class.
	 * Services are additionally registered under each interface they implement.
	 */
	private void doInstance() throws ServletException {
		for (String className : classNames) {
			Class<?> clazz;
			try {
				// Do not initialize classes that turn out not to be beans.
				clazz = Class.forName(className, false, getClass().getClassLoader());
			} catch (ClassNotFoundException e) {
				throw new ServletException("Cannot load scanned class " + className, e);
			}
			XXController controller = clazz.getAnnotation(XXController.class);
			XXService service = clazz.getAnnotation(XXService.class);
			if (controller != null) {
				registerBean(beanName(controller.value(), clazz), createBean(clazz));
			} else if (service != null) {
				Object instance = createBean(clazz);
				registerBean(beanName(service.value(), clazz), instance);
				for (Class<?> type : clazz.getInterfaces()) {
					registerByType(type.getName(), instance);
				}
			}
		}
	}

	/** Instantiates {@code clazz} through its no-argument constructor. */
	private static Object createBean(Class<?> clazz) throws ServletException {
		try {
			return clazz.getDeclaredConstructor().newInstance();
		} catch (Exception e) {
			throw new ServletException("Cannot instantiate bean " + clazz.getName(), e);
		}
	}

	/** Returns {@code explicitName}, or the simple class name with a lower-case first letter. */
	private static String beanName(String explicitName, Class<?> clazz) {
		return "".equals(explicitName) ? lowerFirstCase(clazz.getSimpleName()) : explicitName;
	}

	/** Registers a bean by name; two beans with the same name are a configuration error. */
	private void registerBean(String name, Object bean) throws ServletException {
		Object existing = ioc.get(name);
		if (existing != null) {
			throw new ServletException("Duplicate bean name '" + name + "': "
					+ existing.getClass().getName() + " and " + bean.getClass().getName());
		}
		ioc.put(name, bean);
	}

	/**
	 * Registers {@code bean} under an interface name for by-type injection. If a second bean
	 * claims the same interface, the entry is dropped and marked ambiguous. Otherwise the
	 * winner would depend on the scan order.
	 */
	private void registerByType(String typeName, Object bean) {
		if (ambiguousTypes.contains(typeName)) {
			return;
		}
		Object existing = ioc.get(typeName);
		if (existing != null && existing != bean) {
			ioc.remove(typeName);
			ambiguousTypes.add(typeName);
		} else {
			ioc.put(typeName, bean);
		}
	}

	/** Injects every {@link XXAutowired} field; a dependency that cannot be resolved fails startup. */
	private void doAutoWired() throws ServletException {
		for (Object bean : ioc.values()) {
			for (Field field : bean.getClass().getDeclaredFields()) {
				XXAutowired autowired = field.getAnnotation(XXAutowired.class);
				if (autowired == null) {
					continue;
				}
				// Without an explicit name, look the dependency up by the field's type name,
				// which matches the interface names that services are registered under.
				String beanName = "".equals(autowired.value())
						? field.getType().getName() : autowired.value();
				Object dependency = ioc.get(beanName);
				if (dependency == null || !field.getType().isInstance(dependency)) {
					String reason = ambiguousTypes.contains(beanName) ? "several services implement it"
							: dependency == null ? "no such bean"
							: "bean has type " + dependency.getClass().getName();
					throw new ServletException("Cannot autowire " + bean.getClass().getName() + "."
							+ field.getName() + " from '" + beanName + "': " + reason);
				}
				field.setAccessible(true);
				try {
					field.set(bean, dependency);
				} catch (IllegalAccessException e) {
					throw new ServletException("Cannot set " + field, e);
				}
			}
		}
	}

	/** Registers one {@link Handler} per public {@link XXRequestMapping} method of each controller. */
	private void initHandlerMapping() {
		for (Object bean : ioc.values()) {
			Class<?> clazz = bean.getClass();
			if (!clazz.isAnnotationPresent(XXController.class)) {
				continue;
			}
			XXRequestMapping typeMapping = clazz.getAnnotation(XXRequestMapping.class);
			String baseUrl = typeMapping == null ? "" : typeMapping.value();
			for (Method method : clazz.getMethods()) {
				XXRequestMapping methodMapping = method.getAnnotation(XXRequestMapping.class);
				if (methodMapping == null) {
					continue;
				}
				String url = joinPath(baseUrl, methodMapping.value());
				handlerMapping.add(new Handler(Pattern.compile(url), bean, method));
				log("mapped:" + url + "=>" + method);
			}
		}
	}

	/**
	 * Joins a type-level prefix and a method path with exactly one '/' between segments.
	 * For example, "/user" + "index" yields "/user/index", not "/userindex".
	 */
	private static String joinPath(String base, String path) {
		String joined = "/" + base + (path.isEmpty() ? "" : "/" + path);
		return joined.replaceAll("/+", "/");
	}

	@Override
	protected void doGet(HttpServletRequest req, HttpServletResponse res)
			throws ServletException, IOException {
		this.doPost(req, res);
	}

	@Override
	protected void doPost(HttpServletRequest req, HttpServletResponse res)
			throws ServletException, IOException {
		doDispatcher(req, res);
	}

	/**
	 * Routes one request. When no handler matches, it replies 404 with a short text body.
	 * Otherwise it binds the arguments and invokes the handler method. The method's return
	 * value is ignored; the handler writes the response itself. Failures are logged and
	 * answered with 400 (argument binding) or 500.
	 */
	public void doDispatcher(HttpServletRequest req, HttpServletResponse res) {
		try {
			Handler handler = getHandler(req);
			if (handler == null) {
				res.setStatus(HttpServletResponse.SC_NOT_FOUND);
				res.getWriter().write("404 not found.");
				return;
			}
			Object[] args = resolveArguments(handler, req, res);
			handler.method.invoke(handler.controller, args);
		} catch (InvocationTargetException e) {
			// The controller itself threw; report its exception, not the reflective wrapper.
			reportFailure(res, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, e.getCause());
		} catch (IllegalArgumentException e) {
			// Unparseable numeric parameter, or no value for a primitive parameter.
			reportFailure(res, HttpServletResponse.SC_BAD_REQUEST, e);
		} catch (Exception e) {
			reportFailure(res, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, e);
		}
	}

	/** Builds the handler's argument array from request parameters and the servlet objects. */
	private Object[] resolveArguments(Handler handler, HttpServletRequest req,
			HttpServletResponse res) {
		Class<?>[] paramTypes = handler.method.getParameterTypes();
		Object[] args = new Object[paramTypes.length];
		Map<String, String[]> params = req.getParameterMap();
		for (Entry<String, String[]> param : params.entrySet()) {
			Integer index = handler.paramIndexMapping.get(param.getKey());
			if (index != null) {
				args[index] = convert(paramTypes[index], joinValues(param.getValue()));
			}
		}
		// Request/response are bound only if the handler method declares them.
		Integer reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
		if (reqIndex != null) {
			args[reqIndex] = req;
		}
		Integer resIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
		if (resIndex != null) {
			args[resIndex] = res;
		}
		return args;
	}

	/** Joins multi-valued parameters with ", " and leaves the values themselves untouched. */
	private static String joinValues(String[] values) {
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < values.length; i++) {
			if (i > 0) {
				sb.append(", ");
			}
			sb.append(values[i]);
		}
		return sb.toString();
	}

	/**
	 * Converts a raw parameter to {@code type}. Supports int/long/double/boolean and their
	 * wrappers; any other type receives the string unchanged.
	 *
	 * @throws NumberFormatException if a numeric value cannot be parsed
	 */
	private static Object convert(Class<?> type, String value) {
		if (type == Integer.class || type == int.class) {
			return Integer.valueOf(value);
		}
		if (type == Long.class || type == long.class) {
			return Long.valueOf(value);
		}
		if (type == Double.class || type == double.class) {
			return Double.valueOf(value);
		}
		if (type == Boolean.class || type == boolean.class) {
			return Boolean.valueOf(value);
		}
		return value;
	}

	/** Logs {@code cause} and, if the response is still uncommitted, replies with {@code status}. */
	private void reportFailure(HttpServletResponse res, int status, Throwable cause) {
		log("Request failed with status " + status, cause);
		if (res.isCommitted()) {
			return;
		}
		try {
			res.sendError(status);
		} catch (IOException e) {
			log("Cannot send error status " + status, e);
		}
	}

	/** Returns the first handler whose pattern matches the context-relative URI, or null. */
	private Handler getHandler(HttpServletRequest req) {
		String url = req.getRequestURI();
		String contextPath = req.getContextPath();
		// Strip the context path only as a prefix, not wherever it occurs in the URI.
		if (contextPath != null && url.startsWith(contextPath)) {
			url = url.substring(contextPath.length());
		}
		url = url.replaceAll("/+", "/");
		if (url.isEmpty()) {
			url = "/";
		}
		for (Handler handler : handlerMapping) {
			Matcher matcher = handler.pattern.matcher(url);
			if (matcher.matches()) {
				return handler;
			}
		}
		return null;
	}

	/** Lower-cases the first character; empty or null input is returned unchanged. */
	private static String lowerFirstCase(String str) {
		if (str == null || str.isEmpty()) {
			return str;
		}
		return Character.toLowerCase(str.charAt(0)) + str.substring(1);
	}

	/** One URL mapping: compiled pattern, target controller and method, and argument positions. */
	private static class Handler {
		final Object controller;
		final Method method;
		final Pattern pattern;
		/** Request-parameter name, or servlet request/response class name, to argument index. */
		final Map<String, Integer> paramIndexMapping = new HashMap<String, Integer>();

		Handler(Pattern pattern, Object controller, Method method) {
			this.pattern = pattern;
			this.controller = controller;
			this.method = method;
			putParamIndexMapping(method);
		}

		private void putParamIndexMapping(Method method) {
			Annotation[][] annotations = method.getParameterAnnotations();
			for (int i = 0; i < annotations.length; i++) {
				for (Annotation a : annotations[i]) {
					if (a instanceof XXRequestParam) {
						String paramName = ((XXRequestParam) a).value();
						// This implementation does not look up parameter names by reflection,
						// so an unnamed @XXRequestParam cannot be bound.
						if (!"".equals(paramName)) {
							paramIndexMapping.put(paramName, i);
						}
					}
				}
			}
			Class<?>[] paramTypes = method.getParameterTypes();
			for (int i = 0; i < paramTypes.length; i++) {
				Class<?> type = paramTypes[i];
				if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
					paramIndexMapping.put(type.getName(), i);
				}
			}
		}
	}
}

配置檔案config.properties

# Root package scanned (including sub-packages) by XXDispatcherServlet for @XXController/@XXService classes.
scanPackage=com.xxx.springmvc.web

配置檔案web.xml

<!-- Front controller: XXDispatcherServlet receives the requests routed by the mapping below. -->
<servlet>
      <servlet-name>springmvc</servlet-name>
      <servlet-class>com.xxx.springmvc.servlet.XXDispatcherServlet</servlet-class>
      <!-- Classpath resource that holds scanPackage; read in the servlet's init(). -->
      <init-param>
           <param-name>contextConfigLocation</param-name>
           <param-value>config.properties</param-value>
      </init-param>
      <!-- Create and init() the servlet when the app is deployed, not on the first request. -->
      <load-on-startup>1</load-on-startup>
  </servlet>
  <servlet-mapping>
      <servlet-name>springmvc</servlet-name>
      <!-- "/" makes this the default servlet, so it also receives requests for static files. -->
      <url-pattern>/</url-pattern>
  </servlet-mapping>

3、執行:

控制檯列印日誌:

訪問 http://localhost:8080/springmvc/user/index?name=bb:

訪問 http://localhost:8080/springmvc/user/list:

訪問 http://localhost:8080/springmvc/user/detail 出現404:

相關文章