diff --git a/cores/server.go b/cores/server.go index ed7ba146a..eb490e139 100644 --- a/cores/server.go +++ b/cores/server.go @@ -176,6 +176,11 @@ func (s *Server) ServeGOB(addr string, shdChan *utils.SyncedChan) { func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() w.Header().Set("Content-Type", "application/json") + if origin := r.Header.Get("Origin"); origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Accept-Language, Content-Type") + } rmtIP, _ := utils.GetRemoteIP(r) rmtAddr, _ := net.ResolveIPAddr(utils.EmptyString, rmtIP) res := newRPCRequest(s.rpcSrv, r.Body, rmtAddr, s.caps, s.anz).Call() diff --git a/cores/server_test.go b/cores/server_test.go index 3493d1a76..15c289d7d 100644 --- a/cores/server_test.go +++ b/cores/server_test.go @@ -19,9 +19,11 @@ along with this program. If not, see package cores import ( + "bytes" "io" "log" "net/http" + "net/http/httptest" "os" "reflect" "testing" @@ -103,3 +105,37 @@ func TestRegisterProfiler(t *testing.T) { rcv.StopBiRPC() } + +func TestHandleRequestCORSHeaders(t *testing.T) { + caps := engine.NewCaps(0, utils.MetaBusy) + rcv := NewServer(caps) + + rcv.rpcEnabled = true + + req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:2080/jsonrpc", + bytes.NewBuffer([]byte("1"))) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Origin", "http://origin.com") + + w := httptest.NewRecorder() + + rcv.handleRequest(w, req) + + if origin := req.Header.Get("Origin"); origin != "" { + if got := w.Header().Get("Access-Control-Allow-Origin"); got != origin { + t.Errorf("Expected <%v>, got <%v>", "http://origin.com", got) + } + } + + expectedMethods := "POST, GET, OPTIONS, PUT, DELETE" + if got := w.Header().Get("Access-Control-Allow-Methods"); got != expectedMethods { + t.Errorf("Expected <%v>; got <%v>", expectedMethods, got) + } + + expectedHeaders := "Accept, Accept-Language, Content-Type" + if got := w.Header().Get("Access-Control-Allow-Headers"); got != expectedHeaders { + t.Errorf("Expected <%v>; got <%v>", expectedHeaders, got) + } +}