mirror of
https://github.com/fatedier/frp.git
synced 2025-07-31 23:31:33 +07:00
start refactoring
This commit is contained in:
160
models/config/client_common.go
Normal file
160
models/config/client_common.go
Normal file
@ -0,0 +1,160 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
ini "github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
var ClientCommonCfg *ClientCommonConf
|
||||
|
||||
// client common config
|
||||
type ClientCommonConf struct {
|
||||
ConfigFile string
|
||||
ServerAddr string
|
||||
ServerPort int64
|
||||
HttpProxy string
|
||||
LogFile string
|
||||
LogWay string
|
||||
LogLevel string
|
||||
LogMaxDays int64
|
||||
PrivilegeToken string
|
||||
PoolCount int
|
||||
User string
|
||||
HeartBeatInterval int64
|
||||
HeartBeatTimeout int64
|
||||
}
|
||||
|
||||
func GetDeaultClientCommonConf() *ClientCommonConf {
|
||||
return &ClientCommonConf{
|
||||
ConfigFile: "./frpc.ini",
|
||||
ServerAddr: "0.0.0.0",
|
||||
ServerPort: 7000,
|
||||
HttpProxy: "",
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
PrivilegeToken: "",
|
||||
PoolCount: 1,
|
||||
User: "",
|
||||
HeartBeatInterval: 10,
|
||||
HeartBeatTimeout: 30,
|
||||
}
|
||||
}
|
||||
|
||||
func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
v int64
|
||||
)
|
||||
cfg = GetDeaultClientCommonConf()
|
||||
|
||||
tmpStr, ok = conf.Get("common", "server_addr")
|
||||
if ok {
|
||||
cfg.ServerAddr = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "server_port")
|
||||
if ok {
|
||||
cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64)
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "http_proxy")
|
||||
if ok {
|
||||
cfg.HttpProxy = tmpStr
|
||||
} else {
|
||||
// get http_proxy from env
|
||||
cfg.HttpProxy = os.Getenv("http_proxy")
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "log_file")
|
||||
if ok {
|
||||
cfg.LogFile = tmpStr
|
||||
if cfg.LogFile == "console" {
|
||||
cfg.LogWay = "console"
|
||||
} else {
|
||||
cfg.LogWay = "file"
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "log_level")
|
||||
if ok {
|
||||
cfg.LogLevel = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "log_max_days")
|
||||
if ok {
|
||||
cfg.LogMaxDays, _ = strconv.ParseInt(tmpStr, 10, 64)
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "privilege_token")
|
||||
if ok {
|
||||
cfg.PrivilegeToken = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "pool_count")
|
||||
if ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
cfg.PoolCount = 1
|
||||
} else {
|
||||
cfg.PoolCount = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "user")
|
||||
if ok {
|
||||
cfg.User = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_timeout")
|
||||
if ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect")
|
||||
return
|
||||
} else {
|
||||
cfg.HeartBeatTimeout = v
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_interval")
|
||||
if ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect")
|
||||
return
|
||||
} else {
|
||||
cfg.HeartBeatInterval = v
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.HeartBeatInterval <= 0 {
|
||||
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.HeartBeatTimeout < cfg.HeartBeatInterval {
|
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
446
models/config/proxy.go
Normal file
446
models/config/proxy.go
Normal file
@ -0,0 +1,446 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fatedier/frp/models/consts"
|
||||
"github.com/fatedier/frp/models/msg"
|
||||
|
||||
ini "github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
type ProxyConf interface {
|
||||
GetName() string
|
||||
GetBaseInfo() *BaseProxyConf
|
||||
LoadFromMsg(pMsg *msg.NewProxy)
|
||||
LoadFromFile(name string, conf ini.Section) error
|
||||
UnMarshalToMsg(pMsg *msg.NewProxy)
|
||||
Check() error
|
||||
}
|
||||
|
||||
func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) {
|
||||
if pMsg.ProxyType == "" {
|
||||
pMsg.ProxyType = consts.TcpProxy
|
||||
}
|
||||
switch pMsg.ProxyType {
|
||||
case consts.TcpProxy:
|
||||
cfg = &TcpProxyConf{}
|
||||
case consts.UdpProxy:
|
||||
cfg = &UdpProxyConf{}
|
||||
case consts.HttpProxy:
|
||||
cfg = &HttpProxyConf{}
|
||||
case consts.HttpsProxy:
|
||||
cfg = &HttpsProxyConf{}
|
||||
default:
|
||||
err = fmt.Errorf("proxy [%s] type [%s] error", pMsg.ProxyName, pMsg.ProxyType)
|
||||
return
|
||||
}
|
||||
cfg.LoadFromMsg(pMsg)
|
||||
err = cfg.Check()
|
||||
return
|
||||
}
|
||||
|
||||
func NewProxyConfFromFile(name string, section ini.Section) (cfg ProxyConf, err error) {
|
||||
proxyType := section["type"]
|
||||
if proxyType == "" {
|
||||
proxyType = consts.TcpProxy
|
||||
section["type"] = consts.TcpProxy
|
||||
}
|
||||
switch proxyType {
|
||||
case consts.TcpProxy:
|
||||
cfg = &TcpProxyConf{}
|
||||
case consts.UdpProxy:
|
||||
cfg = &UdpProxyConf{}
|
||||
case consts.HttpProxy:
|
||||
cfg = &HttpProxyConf{}
|
||||
case consts.HttpsProxy:
|
||||
cfg = &HttpsProxyConf{}
|
||||
default:
|
||||
err = fmt.Errorf("proxy [%s] type [%s] error", name, proxyType)
|
||||
return
|
||||
}
|
||||
err = cfg.LoadFromFile(name, section)
|
||||
return
|
||||
}
|
||||
|
||||
// BaseProxy info
|
||||
type BaseProxyConf struct {
|
||||
ProxyName string
|
||||
ProxyType string
|
||||
|
||||
UseEncryption bool
|
||||
UseCompression bool
|
||||
}
|
||||
|
||||
func (cfg *BaseProxyConf) GetName() string {
|
||||
return cfg.ProxyName
|
||||
}
|
||||
|
||||
func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (cfg *BaseProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.ProxyName = pMsg.ProxyName
|
||||
cfg.ProxyType = pMsg.ProxyType
|
||||
cfg.UseEncryption = pMsg.UseEncryption
|
||||
cfg.UseCompression = pMsg.UseCompression
|
||||
}
|
||||
|
||||
func (cfg *BaseProxyConf) LoadFromFile(name string, section ini.Section) error {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
cfg.ProxyName = ClientCommonCfg.User + "." + name
|
||||
cfg.ProxyType = section["type"]
|
||||
|
||||
tmpStr, ok = section["use_encryption"]
|
||||
if ok && tmpStr == "true" {
|
||||
cfg.UseEncryption = true
|
||||
}
|
||||
|
||||
tmpStr, ok = section["use_compression"]
|
||||
if ok && tmpStr == "true" {
|
||||
cfg.UseCompression = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
pMsg.ProxyName = cfg.ProxyName
|
||||
pMsg.ProxyType = cfg.ProxyType
|
||||
pMsg.UseEncryption = cfg.UseEncryption
|
||||
pMsg.UseCompression = cfg.UseCompression
|
||||
}
|
||||
|
||||
// Bind info
|
||||
type BindInfoConf struct {
|
||||
BindAddr string
|
||||
RemotePort int64
|
||||
}
|
||||
|
||||
func (cfg *BindInfoConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BindAddr = ServerCommonCfg.BindAddr
|
||||
cfg.RemotePort = pMsg.RemotePort
|
||||
}
|
||||
|
||||
func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
if tmpStr, ok = section["remote_port"]; ok {
|
||||
if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
pMsg.RemotePort = cfg.RemotePort
|
||||
}
|
||||
|
||||
func (cfg *BindInfoConf) check() (err error) {
|
||||
if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 {
|
||||
if _, ok := ServerCommonCfg.PrivilegeAllowPorts[cfg.RemotePort]; !ok {
|
||||
return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Domain info
|
||||
type DomainConf struct {
|
||||
CustomDomains []string
|
||||
SubDomain string
|
||||
}
|
||||
|
||||
func (cfg *DomainConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.CustomDomains = pMsg.CustomDomains
|
||||
cfg.SubDomain = pMsg.SubDomain
|
||||
}
|
||||
|
||||
func (cfg *DomainConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
if tmpStr, ok = section["custom_domains"]; ok {
|
||||
cfg.CustomDomains = strings.Split(tmpStr, ",")
|
||||
for i, domain := range cfg.CustomDomains {
|
||||
cfg.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain))
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = section["subdomain"]; ok {
|
||||
cfg.SubDomain = tmpStr
|
||||
}
|
||||
|
||||
if len(cfg.CustomDomains) == 0 && cfg.SubDomain == "" {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains and subdomain should set at least one of them", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *DomainConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
pMsg.CustomDomains = cfg.CustomDomains
|
||||
pMsg.SubDomain = cfg.SubDomain
|
||||
}
|
||||
|
||||
func (cfg *DomainConf) check() (err error) {
|
||||
for _, domain := range cfg.CustomDomains {
|
||||
if ServerCommonCfg.SubDomainHost != "" && len(strings.Split(ServerCommonCfg.SubDomainHost, ".")) < len(strings.Split(domain, ".")) {
|
||||
if strings.Contains(domain, ServerCommonCfg.SubDomainHost) {
|
||||
return fmt.Errorf("custom domain [%s] should not belong to subdomain_host [%s]", domain, ServerCommonCfg.SubDomainHost)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.SubDomain != "" {
|
||||
if ServerCommonCfg.SubDomainHost == "" {
|
||||
return fmt.Errorf("subdomain is not supported because this feature is not enabled by frps")
|
||||
}
|
||||
if strings.Contains(cfg.SubDomain, ".") || strings.Contains(cfg.SubDomain, "*") {
|
||||
return fmt.Errorf("'.' and '*' is not supported in subdomain")
|
||||
}
|
||||
cfg.SubDomain += "." + ServerCommonCfg.SubDomainHost
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type LocalSvrConf struct {
|
||||
LocalIp string
|
||||
LocalPort int
|
||||
}
|
||||
|
||||
func (cfg *LocalSvrConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" {
|
||||
cfg.LocalIp = "127.0.0.1"
|
||||
}
|
||||
|
||||
if tmpStr, ok := section["local_port"]; ok {
|
||||
if cfg.LocalPort, err = strconv.Atoi(tmpStr); err != nil {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] local_port error", name)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] local_port not found", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TCP
|
||||
type TcpProxyConf struct {
|
||||
BaseProxyConf
|
||||
BindInfoConf
|
||||
|
||||
LocalSvrConf
|
||||
}
|
||||
|
||||
func (cfg *TcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg)
|
||||
cfg.BindInfoConf.LoadFromMsg(pMsg)
|
||||
}
|
||||
|
||||
func (cfg *TcpProxyConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.BindInfoConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *TcpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg)
|
||||
cfg.BindInfoConf.UnMarshalToMsg(pMsg)
|
||||
}
|
||||
|
||||
func (cfg *TcpProxyConf) Check() (err error) {
|
||||
err = cfg.BindInfoConf.check()
|
||||
return
|
||||
}
|
||||
|
||||
// UDP
|
||||
type UdpProxyConf struct {
|
||||
BaseProxyConf
|
||||
BindInfoConf
|
||||
|
||||
LocalSvrConf
|
||||
}
|
||||
|
||||
func (cfg *UdpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg)
|
||||
cfg.BindInfoConf.LoadFromMsg(pMsg)
|
||||
}
|
||||
|
||||
func (cfg *UdpProxyConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.BindInfoConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *UdpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg)
|
||||
cfg.BindInfoConf.UnMarshalToMsg(pMsg)
|
||||
}
|
||||
|
||||
func (cfg *UdpProxyConf) Check() (err error) {
|
||||
err = cfg.BindInfoConf.check()
|
||||
return
|
||||
}
|
||||
|
||||
// HTTP
|
||||
type HttpProxyConf struct {
|
||||
BaseProxyConf
|
||||
DomainConf
|
||||
|
||||
LocalSvrConf
|
||||
|
||||
Locations []string
|
||||
HostHeaderRewrite string
|
||||
HttpUser string
|
||||
HttpPwd string
|
||||
}
|
||||
|
||||
func (cfg *HttpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg)
|
||||
cfg.DomainConf.LoadFromMsg(pMsg)
|
||||
|
||||
cfg.Locations = pMsg.Locations
|
||||
cfg.HostHeaderRewrite = pMsg.HostHeaderRewrite
|
||||
cfg.HttpUser = pMsg.HttpUser
|
||||
cfg.HttpPwd = pMsg.HttpPwd
|
||||
}
|
||||
|
||||
func (cfg *HttpProxyConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.DomainConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
if tmpStr, ok = section["locations"]; ok {
|
||||
cfg.Locations = strings.Split(tmpStr, ",")
|
||||
} else {
|
||||
cfg.Locations = []string{""}
|
||||
}
|
||||
|
||||
cfg.HostHeaderRewrite = section["host_header_rewrite"]
|
||||
cfg.HttpUser = section["http_user"]
|
||||
cfg.HttpPwd = section["http_pwd"]
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *HttpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg)
|
||||
cfg.DomainConf.UnMarshalToMsg(pMsg)
|
||||
|
||||
pMsg.Locations = cfg.Locations
|
||||
pMsg.HostHeaderRewrite = cfg.HostHeaderRewrite
|
||||
pMsg.HttpUser = cfg.HttpUser
|
||||
pMsg.HttpPwd = cfg.HttpPwd
|
||||
}
|
||||
|
||||
func (cfg *HttpProxyConf) Check() (err error) {
|
||||
if ServerCommonCfg.VhostHttpPort == 0 {
|
||||
return fmt.Errorf("type [http] not support when vhost_http_port is not set")
|
||||
}
|
||||
err = cfg.DomainConf.check()
|
||||
return
|
||||
}
|
||||
|
||||
// HTTPS
|
||||
type HttpsProxyConf struct {
|
||||
BaseProxyConf
|
||||
DomainConf
|
||||
|
||||
LocalSvrConf
|
||||
}
|
||||
|
||||
func (cfg *HttpsProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg)
|
||||
cfg.DomainConf.LoadFromMsg(pMsg)
|
||||
}
|
||||
|
||||
func (cfg *HttpsProxyConf) LoadFromFile(name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.DomainConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *HttpsProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
|
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg)
|
||||
cfg.DomainConf.UnMarshalToMsg(pMsg)
|
||||
}
|
||||
|
||||
func (cfg *HttpsProxyConf) Check() (err error) {
|
||||
if ServerCommonCfg.VhostHttpsPort == 0 {
|
||||
return fmt.Errorf("type [https] not support when vhost_https_port is not set")
|
||||
}
|
||||
err = cfg.DomainConf.check()
|
||||
return
|
||||
}
|
||||
|
||||
func LoadProxyConfFromFile(conf ini.File) (proxyConfs map[string]ProxyConf, err error) {
|
||||
var prefix string
|
||||
if ClientCommonCfg.User != "" {
|
||||
prefix = ClientCommonCfg.User + "."
|
||||
}
|
||||
proxyConfs = make(map[string]ProxyConf)
|
||||
for name, section := range conf {
|
||||
if name != "common" {
|
||||
cfg, err := NewProxyConfFromFile(name, section)
|
||||
if err != nil {
|
||||
return proxyConfs, err
|
||||
}
|
||||
proxyConfs[prefix+name] = cfg
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
279
models/config/server_common.go
Normal file
279
models/config/server_common.go
Normal file
@ -0,0 +1,279 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
ini "github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
var ServerCommonCfg *ServerCommonConf
|
||||
|
||||
// common config
|
||||
type ServerCommonConf struct {
|
||||
ConfigFile string
|
||||
BindAddr string
|
||||
BindPort int64
|
||||
|
||||
// If VhostHttpPort equals 0, don't listen a public port for http protocol.
|
||||
VhostHttpPort int64
|
||||
|
||||
// if VhostHttpsPort equals 0, don't listen a public port for https protocol
|
||||
VhostHttpsPort int64
|
||||
|
||||
// if DashboardPort equals 0, dashboard is not available
|
||||
DashboardPort int64
|
||||
DashboardUser string
|
||||
DashboardPwd string
|
||||
AssetsDir string
|
||||
LogFile string
|
||||
LogWay string // console or file
|
||||
LogLevel string
|
||||
LogMaxDays int64
|
||||
PrivilegeMode bool
|
||||
PrivilegeToken string
|
||||
AuthTimeout int64
|
||||
SubDomainHost string
|
||||
|
||||
// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected
|
||||
PrivilegeAllowPorts map[int64]struct{}
|
||||
MaxPoolCount int64
|
||||
HeartBeatTimeout int64
|
||||
UserConnTimeout int64
|
||||
}
|
||||
|
||||
func GetDefaultServerCommonConf() *ServerCommonConf {
|
||||
return &ServerCommonConf{
|
||||
ConfigFile: "./frps.ini",
|
||||
BindAddr: "0.0.0.0",
|
||||
BindPort: 7000,
|
||||
VhostHttpPort: 0,
|
||||
VhostHttpsPort: 0,
|
||||
DashboardPort: 0,
|
||||
DashboardUser: "admin",
|
||||
DashboardPwd: "admin",
|
||||
AssetsDir: "",
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
PrivilegeMode: true,
|
||||
PrivilegeToken: "",
|
||||
AuthTimeout: 900,
|
||||
SubDomainHost: "",
|
||||
MaxPoolCount: 10,
|
||||
HeartBeatTimeout: 30,
|
||||
UserConnTimeout: 10,
|
||||
}
|
||||
}
|
||||
|
||||
// Load server common configure.
|
||||
func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
v int64
|
||||
)
|
||||
cfg = GetDefaultServerCommonConf()
|
||||
|
||||
tmpStr, ok = conf.Get("common", "bind_addr")
|
||||
if ok {
|
||||
cfg.BindAddr = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "bind_port")
|
||||
if ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err == nil {
|
||||
cfg.BindPort = v
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "vhost_http_port")
|
||||
if ok {
|
||||
cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
cfg.VhostHttpPort = 0
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "vhost_https_port")
|
||||
if ok {
|
||||
cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
cfg.VhostHttpsPort = 0
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_port")
|
||||
if ok {
|
||||
cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Parse conf error: dashboard_port is incorrect")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
cfg.DashboardPort = 0
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_user")
|
||||
if ok {
|
||||
cfg.DashboardUser = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_pwd")
|
||||
if ok {
|
||||
cfg.DashboardPwd = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "assets_dir")
|
||||
if ok {
|
||||
cfg.AssetsDir = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "log_file")
|
||||
if ok {
|
||||
cfg.LogFile = tmpStr
|
||||
if cfg.LogFile == "console" {
|
||||
cfg.LogWay = "console"
|
||||
} else {
|
||||
cfg.LogWay = "file"
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "log_level")
|
||||
if ok {
|
||||
cfg.LogLevel = tmpStr
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "log_max_days")
|
||||
if ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err == nil {
|
||||
cfg.LogMaxDays = v
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "privilege_mode")
|
||||
if ok {
|
||||
if tmpStr == "true" {
|
||||
cfg.PrivilegeMode = true
|
||||
}
|
||||
}
|
||||
|
||||
// PrivilegeMode configure
|
||||
if cfg.PrivilegeMode == true {
|
||||
tmpStr, ok = conf.Get("common", "privilege_token")
|
||||
if ok {
|
||||
if tmpStr == "" {
|
||||
err = fmt.Errorf("Parse conf error: privilege_token can not be empty")
|
||||
return
|
||||
}
|
||||
cfg.PrivilegeToken = tmpStr
|
||||
} else {
|
||||
err = fmt.Errorf("Parse conf error: privilege_token must be set if privilege_mode is enabled")
|
||||
return
|
||||
}
|
||||
|
||||
cfg.PrivilegeAllowPorts = make(map[int64]struct{})
|
||||
tmpStr, ok = conf.Get("common", "privilege_allow_ports")
|
||||
if ok {
|
||||
// e.g. 1000-2000,2001,2002,3000-4000
|
||||
portRanges := strings.Split(tmpStr, ",")
|
||||
for _, portRangeStr := range portRanges {
|
||||
// 1000-2000 or 2001
|
||||
portArray := strings.Split(portRangeStr, "-")
|
||||
// length: only 1 or 2 is correct
|
||||
rangeType := len(portArray)
|
||||
if rangeType == 1 {
|
||||
// single port
|
||||
singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
|
||||
return
|
||||
}
|
||||
ServerCommonCfg.PrivilegeAllowPorts[singlePort] = struct{}{}
|
||||
} else if rangeType == 2 {
|
||||
// range ports
|
||||
min, errRet := strconv.ParseInt(portArray[0], 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
|
||||
return
|
||||
}
|
||||
max, errRet := strconv.ParseInt(portArray[1], 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
|
||||
return
|
||||
}
|
||||
if max < min {
|
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect")
|
||||
return
|
||||
}
|
||||
for i := min; i <= max; i++ {
|
||||
cfg.PrivilegeAllowPorts[i] = struct{}{}
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "max_pool_count")
|
||||
if ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err == nil && v >= 0 {
|
||||
cfg.MaxPoolCount = v
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "authentication_timeout")
|
||||
if ok {
|
||||
v, errRet := strconv.ParseInt(tmpStr, 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: authentication_timeout is incorrect")
|
||||
return
|
||||
} else {
|
||||
cfg.AuthTimeout = v
|
||||
}
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "subdomain_host")
|
||||
if ok {
|
||||
cfg.SubDomainHost = strings.ToLower(strings.TrimSpace(tmpStr))
|
||||
}
|
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_timeout")
|
||||
if ok {
|
||||
v, errRet := strconv.ParseInt(tmpStr, 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect")
|
||||
return
|
||||
} else {
|
||||
cfg.HeartBeatTimeout = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
28
models/consts/consts.go
Normal file
28
models/consts/consts.go
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consts
|
||||
|
||||
var (
|
||||
// server status
|
||||
Idle string = "idle"
|
||||
Working string = "working"
|
||||
Closed string = "closed"
|
||||
|
||||
// proxy type
|
||||
TcpProxy string = "tcp"
|
||||
UdpProxy string = "udp"
|
||||
HttpProxy string = "http"
|
||||
HttpsProxy string = "https"
|
||||
)
|
21
models/errors/errors.go
Normal file
21
models/errors/errors.go
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package errors
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrMsgType = errors.New("message type error")
|
||||
)
|
221
models/metric/server.go
Normal file
221
models/metric/server.go
Normal file
@ -0,0 +1,221 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metric
|
||||
|
||||
/*
|
||||
var (
|
||||
DailyDataKeepDays int = 7
|
||||
ServerMetricInfoMap map[string]*ServerMetric
|
||||
smMutex sync.RWMutex
|
||||
)
|
||||
|
||||
type ServerMetric struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
BindAddr string `json:"bind_addr"`
|
||||
ListenPort int64 `json:"listen_port"`
|
||||
CustomDomains []string `json:"custom_domains"`
|
||||
Locations []string `json:"locations"`
|
||||
Status string `json:"status"`
|
||||
UseEncryption bool `json:"use_encryption"`
|
||||
UseGzip bool `json:"use_gzip"`
|
||||
PrivilegeMode bool `json:"privilege_mode"`
|
||||
|
||||
// statistics
|
||||
CurrentConns int64 `json:"current_conns"`
|
||||
Daily []*DailyServerStats `json:"daily"`
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type DailyServerStats struct {
|
||||
Time string `json:"time"`
|
||||
FlowIn int64 `json:"flow_in"`
|
||||
FlowOut int64 `json:"flow_out"`
|
||||
TotalAcceptConns int64 `json:"total_accept_conns"`
|
||||
}
|
||||
|
||||
// for sort
|
||||
type ServerMetricList []*ServerMetric
|
||||
|
||||
func (l ServerMetricList) Len() int { return len(l) }
|
||||
func (l ServerMetricList) Less(i, j int) bool { return l[i].Name < l[j].Name }
|
||||
func (l ServerMetricList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
|
||||
|
||||
func init() {
|
||||
ServerMetricInfoMap = make(map[string]*ServerMetric)
|
||||
}
|
||||
|
||||
func (s *ServerMetric) clone() *ServerMetric {
|
||||
copy := *s
|
||||
copy.CustomDomains = make([]string, len(s.CustomDomains))
|
||||
var i int
|
||||
for i = range copy.CustomDomains {
|
||||
copy.CustomDomains[i] = s.CustomDomains[i]
|
||||
}
|
||||
|
||||
copy.Daily = make([]*DailyServerStats, len(s.Daily))
|
||||
for i = range copy.Daily {
|
||||
tmpDaily := *s.Daily[i]
|
||||
copy.Daily[i] = &tmpDaily
|
||||
}
|
||||
return ©
|
||||
}
|
||||
|
||||
func GetAllProxyMetrics() []*ServerMetric {
|
||||
result := make(ServerMetricList, 0)
|
||||
smMutex.RLock()
|
||||
for _, metric := range ServerMetricInfoMap {
|
||||
metric.mutex.RLock()
|
||||
tmpMetric := metric.clone()
|
||||
metric.mutex.RUnlock()
|
||||
result = append(result, tmpMetric)
|
||||
}
|
||||
smMutex.RUnlock()
|
||||
|
||||
// sort for result by proxy name
|
||||
sort.Sort(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// if proxyName isn't exist, return nil
|
||||
func GetProxyMetrics(proxyName string) *ServerMetric {
|
||||
smMutex.RLock()
|
||||
defer smMutex.RUnlock()
|
||||
metric, ok := ServerMetricInfoMap[proxyName]
|
||||
if ok {
|
||||
metric.mutex.RLock()
|
||||
tmpMetric := metric.clone()
|
||||
metric.mutex.RUnlock()
|
||||
return tmpMetric
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func SetProxyInfo(proxyName string, proxyType, bindAddr string,
|
||||
useEncryption, useGzip, privilegeMode bool, customDomains []string,
|
||||
locations []string, listenPort int64) {
|
||||
smMutex.Lock()
|
||||
info, ok := ServerMetricInfoMap[proxyName]
|
||||
if !ok {
|
||||
info = &ServerMetric{}
|
||||
info.Daily = make([]*DailyServerStats, 0)
|
||||
}
|
||||
info.Name = proxyName
|
||||
info.Type = proxyType
|
||||
info.UseEncryption = useEncryption
|
||||
info.UseGzip = useGzip
|
||||
info.PrivilegeMode = privilegeMode
|
||||
info.BindAddr = bindAddr
|
||||
info.ListenPort = listenPort
|
||||
info.CustomDomains = customDomains
|
||||
info.Locations = locations
|
||||
ServerMetricInfoMap[proxyName] = info
|
||||
smMutex.Unlock()
|
||||
}
|
||||
|
||||
func SetStatus(proxyName string, status int64) {
|
||||
smMutex.RLock()
|
||||
metric, ok := ServerMetricInfoMap[proxyName]
|
||||
smMutex.RUnlock()
|
||||
if ok {
|
||||
metric.mutex.Lock()
|
||||
metric.Status = consts.StatusStr[status]
|
||||
metric.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
type DealFuncType func(*DailyServerStats)
|
||||
|
||||
func DealDailyData(dailyData []*DailyServerStats, fn DealFuncType) (newDailyData []*DailyServerStats) {
|
||||
now := time.Now().Format("20060102")
|
||||
dailyLen := len(dailyData)
|
||||
if dailyLen == 0 {
|
||||
daily := &DailyServerStats{}
|
||||
daily.Time = now
|
||||
fn(daily)
|
||||
dailyData = append(dailyData, daily)
|
||||
} else {
|
||||
daily := dailyData[dailyLen-1]
|
||||
if daily.Time == now {
|
||||
fn(daily)
|
||||
} else {
|
||||
newDaily := &DailyServerStats{}
|
||||
newDaily.Time = now
|
||||
fn(newDaily)
|
||||
if dailyLen == DailyDataKeepDays {
|
||||
for i := 0; i < dailyLen-1; i++ {
|
||||
dailyData[i] = dailyData[i+1]
|
||||
}
|
||||
dailyData[dailyLen-1] = newDaily
|
||||
} else {
|
||||
dailyData = append(dailyData, newDaily)
|
||||
}
|
||||
}
|
||||
}
|
||||
return dailyData
|
||||
}
|
||||
|
||||
func OpenConnection(proxyName string) {
|
||||
smMutex.RLock()
|
||||
metric, ok := ServerMetricInfoMap[proxyName]
|
||||
smMutex.RUnlock()
|
||||
if ok {
|
||||
metric.mutex.Lock()
|
||||
metric.CurrentConns++
|
||||
metric.Daily = DealDailyData(metric.Daily, func(stats *DailyServerStats) {
|
||||
stats.TotalAcceptConns++
|
||||
})
|
||||
metric.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func CloseConnection(proxyName string) {
|
||||
smMutex.RLock()
|
||||
metric, ok := ServerMetricInfoMap[proxyName]
|
||||
smMutex.RUnlock()
|
||||
if ok {
|
||||
metric.mutex.Lock()
|
||||
metric.CurrentConns--
|
||||
metric.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func AddFlowIn(proxyName string, value int64) {
|
||||
smMutex.RLock()
|
||||
metric, ok := ServerMetricInfoMap[proxyName]
|
||||
smMutex.RUnlock()
|
||||
if ok {
|
||||
metric.mutex.Lock()
|
||||
metric.Daily = DealDailyData(metric.Daily, func(stats *DailyServerStats) {
|
||||
stats.FlowIn += value
|
||||
})
|
||||
metric.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func AddFlowOut(proxyName string, value int64) {
|
||||
smMutex.RLock()
|
||||
metric, ok := ServerMetricInfoMap[proxyName]
|
||||
smMutex.RUnlock()
|
||||
if ok {
|
||||
metric.mutex.Lock()
|
||||
metric.Daily = DealDailyData(metric.Daily, func(stats *DailyServerStats) {
|
||||
stats.FlowOut += value
|
||||
})
|
||||
metric.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
*/
|
122
models/msg/msg.go
Normal file
122
models/msg/msg.go
Normal file
@ -0,0 +1,122 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import "reflect"
|
||||
|
||||
const (
|
||||
TypeLogin = 'o'
|
||||
TypeLoginResp = '1'
|
||||
TypeNewProxy = 'p'
|
||||
TypeNewProxyResp = '2'
|
||||
TypeNewWorkConn = 'w'
|
||||
TypeReqWorkConn = 'r'
|
||||
TypeStartWorkConn = 's'
|
||||
TypePing = 'h'
|
||||
TypePong = '4'
|
||||
)
|
||||
|
||||
var (
|
||||
TypeMap map[byte]reflect.Type
|
||||
TypeStringMap map[reflect.Type]byte
|
||||
)
|
||||
|
||||
func init() {
|
||||
TypeMap = make(map[byte]reflect.Type)
|
||||
TypeStringMap = make(map[reflect.Type]byte)
|
||||
|
||||
TypeMap[TypeLogin] = getTypeFn((*Login)(nil))
|
||||
TypeMap[TypeLoginResp] = getTypeFn((*LoginResp)(nil))
|
||||
TypeMap[TypeNewProxy] = getTypeFn((*NewProxy)(nil))
|
||||
TypeMap[TypeNewProxyResp] = getTypeFn((*NewProxyResp)(nil))
|
||||
TypeMap[TypeNewWorkConn] = getTypeFn((*NewWorkConn)(nil))
|
||||
TypeMap[TypeReqWorkConn] = getTypeFn((*ReqWorkConn)(nil))
|
||||
TypeMap[TypeStartWorkConn] = getTypeFn((*StartWorkConn)(nil))
|
||||
TypeMap[TypePing] = getTypeFn((*Ping)(nil))
|
||||
TypeMap[TypePong] = getTypeFn((*Pong)(nil))
|
||||
|
||||
for k, v := range TypeMap {
|
||||
TypeStringMap[v] = k
|
||||
}
|
||||
}
|
||||
|
||||
func getTypeFn(obj interface{}) reflect.Type {
|
||||
return reflect.TypeOf(obj).Elem()
|
||||
}
|
||||
|
||||
// Message wraps socket packages for communicating between frpc and frps.
|
||||
type Message interface{}
|
||||
|
||||
// When frpc start, client send this message to login to server.
|
||||
type Login struct {
|
||||
Version string `json:"version"`
|
||||
Hostname string `json:"hostname"`
|
||||
Os string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
User string `json:"user"`
|
||||
PrivilegeKey string `json:"privilege_key"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
RunId string `json:"run_id"`
|
||||
|
||||
// Some global configures.
|
||||
PoolCount int `json:"pool_count"`
|
||||
}
|
||||
|
||||
type LoginResp struct {
|
||||
Version string `json:"version"`
|
||||
RunId string `json:"run_id"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// When frpc login success, send this message to frps for running a new proxy.
|
||||
type NewProxy struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
ProxyType string `json:"proxy_type"`
|
||||
UseEncryption bool `json:"use_encryption"`
|
||||
UseCompression bool `json:"use_compression"`
|
||||
|
||||
// tcp and udp only
|
||||
RemotePort int64 `json:"remote_port"`
|
||||
|
||||
// http and https only
|
||||
CustomDomains []string `json:"custom_domains"`
|
||||
SubDomain string `json:"subdomain"`
|
||||
Locations []string `json:"locations"`
|
||||
HostHeaderRewrite string `json:"host_header_rewrite"`
|
||||
HttpUser string `json:"http_user"`
|
||||
HttpPwd string `json:"http_pwd"`
|
||||
}
|
||||
|
||||
type NewProxyResp struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type NewWorkConn struct {
|
||||
RunId string `json:"run_id"`
|
||||
}
|
||||
|
||||
type ReqWorkConn struct {
|
||||
}
|
||||
|
||||
type StartWorkConn struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
}
|
||||
|
||||
type Ping struct {
|
||||
}
|
||||
|
||||
type Pong struct {
|
||||
}
|
69
models/msg/pack.go
Normal file
69
models/msg/pack.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/utils/errors"
|
||||
)
|
||||
|
||||
func unpack(typeByte byte, buffer []byte, msgIn Message) (msg Message, err error) {
|
||||
if msgIn == nil {
|
||||
t, ok := TypeMap[typeByte]
|
||||
if !ok {
|
||||
err = fmt.Errorf("Unsupported message type %b", typeByte)
|
||||
return
|
||||
}
|
||||
|
||||
msg = reflect.New(t).Interface().(Message)
|
||||
} else {
|
||||
msg = msgIn
|
||||
}
|
||||
|
||||
err = json.Unmarshal(buffer, &msg)
|
||||
return
|
||||
}
|
||||
|
||||
func UnPackInto(buffer []byte, msg Message) (err error) {
|
||||
_, err = unpack(' ', buffer, msg)
|
||||
return
|
||||
}
|
||||
|
||||
func UnPack(typeByte byte, buffer []byte) (msg Message, err error) {
|
||||
return unpack(typeByte, buffer, nil)
|
||||
}
|
||||
|
||||
func Pack(msg Message) ([]byte, error) {
|
||||
typeByte, ok := TypeStringMap[reflect.TypeOf(msg).Elem()]
|
||||
if !ok {
|
||||
return nil, errors.ErrMsgType
|
||||
}
|
||||
|
||||
content, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
buffer.WriteByte(typeByte)
|
||||
binary.Write(buffer, binary.BigEndian, int64(len(content)))
|
||||
buffer.Write(content)
|
||||
return buffer.Bytes(), nil
|
||||
}
|
86
models/msg/pack_test.go
Normal file
86
models/msg/pack_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/fatedier/frp/utils/errors"
|
||||
)
|
||||
|
||||
type TestStruct struct{}
|
||||
|
||||
func TestPack(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var (
|
||||
msg Message
|
||||
buffer []byte
|
||||
err error
|
||||
)
|
||||
|
||||
// error type
|
||||
msg = &TestStruct{}
|
||||
buffer, err = Pack(msg)
|
||||
assert.Error(err, errors.ErrMsgType.Error())
|
||||
|
||||
// correct
|
||||
msg = &Ping{}
|
||||
buffer, err = Pack(msg)
|
||||
assert.NoError(err)
|
||||
b := bytes.NewBuffer(nil)
|
||||
b.WriteByte(TypePing)
|
||||
binary.Write(b, binary.BigEndian, int64(2))
|
||||
b.WriteString("{}")
|
||||
assert.True(bytes.Equal(b.Bytes(), buffer))
|
||||
}
|
||||
|
||||
func TestUnPack(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var (
|
||||
msg Message
|
||||
err error
|
||||
)
|
||||
|
||||
// error message type
|
||||
msg, err = UnPack('-', []byte("{}"))
|
||||
assert.Error(err)
|
||||
|
||||
// correct
|
||||
msg, err = UnPack(TypePong, []byte("{}"))
|
||||
assert.NoError(err)
|
||||
assert.Equal(getTypeFn(msg), getTypeFn((*Pong)(nil)))
|
||||
}
|
||||
|
||||
func TestUnPackInto(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var err error
|
||||
|
||||
// correct type
|
||||
pongMsg := &Pong{}
|
||||
err = UnPackInto([]byte("{}"), pongMsg)
|
||||
assert.NoError(err)
|
||||
|
||||
// wrong type
|
||||
loginMsg := &Login{}
|
||||
err = UnPackInto([]byte(`{"version": 123}`), loginMsg)
|
||||
assert.Error(err)
|
||||
}
|
88
models/msg/process.go
Normal file
88
models/msg/process.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
MaxMsgLength int64 = 10240
|
||||
)
|
||||
|
||||
func readMsg(c io.Reader) (typeByte byte, buffer []byte, err error) {
|
||||
buffer = make([]byte, 1)
|
||||
_, err = c.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
typeByte = buffer[0]
|
||||
if _, ok := TypeMap[typeByte]; !ok {
|
||||
err = fmt.Errorf("Message type error")
|
||||
return
|
||||
}
|
||||
|
||||
var length int64
|
||||
err = binary.Read(c, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if length > MaxMsgLength {
|
||||
err = fmt.Errorf("Message length exceed the limit")
|
||||
return
|
||||
}
|
||||
|
||||
buffer = make([]byte, length)
|
||||
n, err := io.ReadFull(c, buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if int64(n) != length {
|
||||
err = fmt.Errorf("Message format error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ReadMsg(c io.Reader) (msg Message, err error) {
|
||||
typeByte, buffer, err := readMsg(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return UnPack(typeByte, buffer)
|
||||
}
|
||||
|
||||
func ReadMsgInto(c io.Reader, msg Message) (err error) {
|
||||
_, buffer, err := readMsg(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return UnPackInto(buffer, msg)
|
||||
}
|
||||
|
||||
func WriteMsg(c io.Writer, msg interface{}) (err error) {
|
||||
buffer, err := Pack(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = c.Write(buffer); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
97
models/msg/process_test.go
Normal file
97
models/msg/process_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProcess(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var (
|
||||
msg Message
|
||||
resMsg Message
|
||||
err error
|
||||
)
|
||||
// empty struct
|
||||
msg = &Ping{}
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
err = WriteMsg(buffer, msg)
|
||||
assert.NoError(err)
|
||||
|
||||
resMsg, err = ReadMsg(buffer)
|
||||
assert.NoError(err)
|
||||
assert.Equal(reflect.TypeOf(resMsg).Elem(), TypeMap[TypePing])
|
||||
|
||||
// normal message
|
||||
msg = &StartWorkConn{
|
||||
ProxyName: "test",
|
||||
}
|
||||
buffer = bytes.NewBuffer(nil)
|
||||
err = WriteMsg(buffer, msg)
|
||||
assert.NoError(err)
|
||||
|
||||
resMsg, err = ReadMsg(buffer)
|
||||
assert.NoError(err)
|
||||
assert.Equal(reflect.TypeOf(resMsg).Elem(), TypeMap[TypeStartWorkConn])
|
||||
|
||||
startWorkConnMsg, ok := resMsg.(*StartWorkConn)
|
||||
assert.True(ok)
|
||||
assert.Equal("test", startWorkConnMsg.ProxyName)
|
||||
|
||||
// ReadMsgInto correct
|
||||
msg = &Pong{}
|
||||
buffer = bytes.NewBuffer(nil)
|
||||
err = WriteMsg(buffer, msg)
|
||||
assert.NoError(err)
|
||||
|
||||
err = ReadMsgInto(buffer, msg)
|
||||
assert.NoError(err)
|
||||
|
||||
// ReadMsgInto error type
|
||||
content := []byte(`{"run_id": 123}`)
|
||||
buffer = bytes.NewBuffer(nil)
|
||||
buffer.WriteByte(TypeNewWorkConn)
|
||||
binary.Write(buffer, binary.BigEndian, int64(len(content)))
|
||||
buffer.Write(content)
|
||||
|
||||
resMsg = &NewWorkConn{}
|
||||
err = ReadMsgInto(buffer, resMsg)
|
||||
assert.Error(err)
|
||||
|
||||
// message format error
|
||||
buffer = bytes.NewBuffer([]byte("1234"))
|
||||
|
||||
resMsg = &NewProxyResp{}
|
||||
err = ReadMsgInto(buffer, resMsg)
|
||||
assert.Error(err)
|
||||
|
||||
// MaxLength, real message length is 2
|
||||
MaxMsgLength = 1
|
||||
msg = &Ping{}
|
||||
buffer = bytes.NewBuffer(nil)
|
||||
err = WriteMsg(buffer, msg)
|
||||
assert.NoError(err)
|
||||
|
||||
_, err = ReadMsg(buffer)
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
38
models/proto/tcp/process.go
Normal file
38
models/proto/tcp/process.go
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Join two io.ReadWriteCloser and do some operations.
|
||||
func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) {
|
||||
var wait sync.WaitGroup
|
||||
pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) {
|
||||
defer to.Close()
|
||||
defer from.Close()
|
||||
defer wait.Done()
|
||||
|
||||
*count, _ = io.Copy(to, from)
|
||||
}
|
||||
|
||||
wait.Add(2)
|
||||
go pipe(c1, c2, &inCount)
|
||||
go pipe(c2, c1, &outCount)
|
||||
wait.Wait()
|
||||
return
|
||||
}
|
129
models/proto/tcp/process_test.go
Normal file
129
models/proto/tcp/process_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/fatedier/frp/utils/crypto"
|
||||
)
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
text1 := "A document that gives tips for writing clear, idiomatic Go code. A must read for any new Go programmer. It augments the tour and the language specification, both of which should be read first."
|
||||
text2 := "A document that specifies the conditions under which reads of a variable in one goroutine can be guaranteed to observe values produced by writes to the same variable in a different goroutine."
|
||||
|
||||
// Forward bytes directly.
|
||||
pr, pw := io.Pipe()
|
||||
pr2, pw2 := io.Pipe()
|
||||
pr3, pw3 := io.Pipe()
|
||||
pr4, pw4 := io.Pipe()
|
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2)
|
||||
conn2 := WrapReadWriteCloser(pr2, pw)
|
||||
conn3 := WrapReadWriteCloser(pr3, pw4)
|
||||
conn4 := WrapReadWriteCloser(pr4, pw3)
|
||||
|
||||
go func() {
|
||||
Join(conn2, conn3)
|
||||
}()
|
||||
|
||||
buf1 := make([]byte, 1024)
|
||||
buf2 := make([]byte, 1024)
|
||||
|
||||
conn1.Write([]byte(text1))
|
||||
conn4.Write([]byte(text2))
|
||||
|
||||
n, err = conn4.Read(buf1)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text1, string(buf1[:n]))
|
||||
|
||||
n, err = conn1.Read(buf2)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text2, string(buf2[:n]))
|
||||
|
||||
conn1.Close()
|
||||
conn2.Close()
|
||||
conn3.Close()
|
||||
conn4.Close()
|
||||
}
|
||||
|
||||
func TestJoinEncrypt(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
text1 := "1234567890"
|
||||
text2 := "abcdefghij"
|
||||
key := "authkey"
|
||||
|
||||
// Forward enrypted bytes.
|
||||
pr, pw := io.Pipe()
|
||||
pr2, pw2 := io.Pipe()
|
||||
pr3, pw3 := io.Pipe()
|
||||
pr4, pw4 := io.Pipe()
|
||||
pr5, pw5 := io.Pipe()
|
||||
pr6, pw6 := io.Pipe()
|
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2)
|
||||
conn2 := WrapReadWriteCloser(pr2, pw)
|
||||
conn3 := WrapReadWriteCloser(pr3, pw4)
|
||||
conn4 := WrapReadWriteCloser(pr4, pw3)
|
||||
conn5 := WrapReadWriteCloser(pr5, pw6)
|
||||
conn6 := WrapReadWriteCloser(pr6, pw5)
|
||||
|
||||
r1, err := crypto.NewReader(conn3, []byte(key))
|
||||
assert.NoError(err)
|
||||
w1, err := crypto.NewWriter(conn3, []byte(key))
|
||||
assert.NoError(err)
|
||||
|
||||
r2, err := crypto.NewReader(conn4, []byte(key))
|
||||
assert.NoError(err)
|
||||
w2, err := crypto.NewWriter(conn4, []byte(key))
|
||||
assert.NoError(err)
|
||||
|
||||
go Join(conn2, WrapReadWriteCloser(r1, w1))
|
||||
go Join(WrapReadWriteCloser(r2, w2), conn5)
|
||||
|
||||
buf := make([]byte, 128)
|
||||
|
||||
conn1.Write([]byte(text1))
|
||||
conn6.Write([]byte(text2))
|
||||
|
||||
n, err = conn6.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text1, string(buf[:n]))
|
||||
|
||||
n, err = conn1.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text2, string(buf[:n]))
|
||||
|
||||
conn1.Close()
|
||||
conn2.Close()
|
||||
conn3.Close()
|
||||
conn4.Close()
|
||||
conn5.Close()
|
||||
conn6.Close()
|
||||
}
|
89
models/proto/tcp/tcp.go
Normal file
89
models/proto/tcp/tcp.go
Normal file
@ -0,0 +1,89 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/golang/snappy"
|
||||
|
||||
"github.com/fatedier/frp/utils/crypto"
|
||||
)
|
||||
|
||||
func WithEncryption(rwc io.ReadWriteCloser, key []byte) (res io.ReadWriteCloser, err error) {
|
||||
var (
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
)
|
||||
r, err = crypto.NewReader(rwc, key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w, err = crypto.NewWriter(rwc, key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = WrapReadWriteCloser(r, w)
|
||||
return
|
||||
}
|
||||
|
||||
func WithCompression(rwc io.ReadWriteCloser) (res io.ReadWriteCloser) {
|
||||
var (
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
)
|
||||
r = snappy.NewReader(rwc)
|
||||
w = snappy.NewWriter(rwc)
|
||||
res = WrapReadWriteCloser(r, w)
|
||||
return
|
||||
}
|
||||
|
||||
func WrapReadWriteCloser(r io.Reader, w io.Writer) io.ReadWriteCloser {
|
||||
return &ReadWriteCloser{
|
||||
r: r,
|
||||
w: w,
|
||||
}
|
||||
}
|
||||
|
||||
type ReadWriteCloser struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (rwc *ReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||
return rwc.r.Read(p)
|
||||
}
|
||||
|
||||
func (rwc *ReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return rwc.w.Write(p)
|
||||
}
|
||||
|
||||
func (rwc *ReadWriteCloser) Close() (errRet error) {
|
||||
var err error
|
||||
if rc, ok := rwc.r.(io.Closer); ok {
|
||||
err = rc.Close()
|
||||
if err != nil {
|
||||
errRet = err
|
||||
}
|
||||
}
|
||||
|
||||
if wc, ok := rwc.w.(io.Closer); ok {
|
||||
err = wc.Close()
|
||||
if err != nil {
|
||||
errRet = err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
100
models/proto/tcp/tcp_test.go
Normal file
100
models/proto/tcp/tcp_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWithCompression(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Forward compression bytes.
|
||||
pr, pw := io.Pipe()
|
||||
pr2, pw2 := io.Pipe()
|
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2)
|
||||
conn2 := WrapReadWriteCloser(pr2, pw)
|
||||
|
||||
compressionStream1 := WithCompression(conn1)
|
||||
compressionStream2 := WithCompression(conn2)
|
||||
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
text := "1234567812345678"
|
||||
buf := make([]byte, 256)
|
||||
|
||||
go compressionStream1.Write([]byte(text))
|
||||
n, err = compressionStream2.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text, string(buf[:n]))
|
||||
|
||||
go compressionStream2.Write([]byte(text))
|
||||
n, err = compressionStream1.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text, string(buf[:n]))
|
||||
}
|
||||
|
||||
func TestWithEncryption(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
text1 := "Go is expressive, concise, clean, and efficient. Its concurrency mechanisms make it easy to write programs that get the most out of multicore and networked machines, while its novel type system enables flexible and modular program construction. Go compiles quickly to machine code yet has the convenience of garbage collection and the power of run-time reflection. It's a fast, statically typed, compiled language that feels like a dynamically typed, interpreted language."
|
||||
text2 := "An interactive introduction to Go in three sections. The first section covers basic syntax and data structures; the second discusses methods and interfaces; and the third introduces Go's concurrency primitives. Each section concludes with a few exercises so you can practice what you've learned. You can take the tour online or install it locally with"
|
||||
key := "authkey"
|
||||
|
||||
// Forward enrypted bytes.
|
||||
pr, pw := io.Pipe()
|
||||
pr2, pw2 := io.Pipe()
|
||||
pr3, pw3 := io.Pipe()
|
||||
pr4, pw4 := io.Pipe()
|
||||
pr5, pw5 := io.Pipe()
|
||||
pr6, pw6 := io.Pipe()
|
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2)
|
||||
conn2 := WrapReadWriteCloser(pr2, pw)
|
||||
conn3 := WrapReadWriteCloser(pr3, pw4)
|
||||
conn4 := WrapReadWriteCloser(pr4, pw3)
|
||||
conn5 := WrapReadWriteCloser(pr5, pw6)
|
||||
conn6 := WrapReadWriteCloser(pr6, pw5)
|
||||
|
||||
encryptStream1, err := WithEncryption(conn3, []byte(key))
|
||||
assert.NoError(err)
|
||||
encryptStream2, err := WithEncryption(conn4, []byte(key))
|
||||
assert.NoError(err)
|
||||
|
||||
go Join(conn2, encryptStream1)
|
||||
go Join(encryptStream2, conn5)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
|
||||
conn1.Write([]byte(text1))
|
||||
conn6.Write([]byte(text2))
|
||||
|
||||
n, err = conn6.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(text1, string(buf[:n]))
|
||||
|
||||
n, err = conn1.Read(buf)
|
||||
assert.NoError(err)
|
||||
}
|
72
models/proto/udp/udp.go
Normal file
72
models/proto/udp/udp.go
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net"
|
||||
)
|
||||
|
||||
type UdpPacket struct {
|
||||
Content []byte `json:"-"`
|
||||
Src *net.UDPAddr `json:"-"`
|
||||
Dst *net.UDPAddr `json:"-"`
|
||||
|
||||
EncodeContent string `json:"content"`
|
||||
SrcStr string `json:"src"`
|
||||
DstStr string `json:"dst"`
|
||||
}
|
||||
|
||||
func NewUdpPacket(content []byte, src, dst *net.UDPAddr) *UdpPacket {
|
||||
up := &UdpPacket{
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
EncodeContent: base64.StdEncoding.EncodeToString(content),
|
||||
SrcStr: src.String(),
|
||||
DstStr: dst.String(),
|
||||
}
|
||||
return up
|
||||
}
|
||||
|
||||
// parse one udp packet struct to bytes
|
||||
func (up *UdpPacket) Pack() []byte {
|
||||
b, _ := json.Marshal(up)
|
||||
return b
|
||||
}
|
||||
|
||||
// parse from bytes to UdpPacket struct
|
||||
func (up *UdpPacket) UnPack(packet []byte) error {
|
||||
err := json.Unmarshal(packet, &up)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
up.Content, err = base64.StdEncoding.DecodeString(up.EncodeContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
up.Src, err = net.ResolveUDPAddr("udp", up.SrcStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
up.Dst, err = net.ResolveUDPAddr("udp", up.DstStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
50
models/proto/udp/udp_test.go
Normal file
50
models/proto/udp/udp_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
content string = "udp packet test"
|
||||
src string = "1.1.1.1:1000"
|
||||
dst string = "2.2.2.2:2000"
|
||||
|
||||
udpMsg *UdpPacket
|
||||
)
|
||||
|
||||
func init() {
|
||||
srcAddr, _ := net.ResolveUDPAddr("udp", src)
|
||||
dstAddr, _ := net.ResolveUDPAddr("udp", dst)
|
||||
udpMsg = NewUdpPacket([]byte(content), srcAddr, dstAddr)
|
||||
}
|
||||
|
||||
func TestPack(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
msg := udpMsg.Pack()
|
||||
assert.Equal(string(msg), `{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`)
|
||||
}
|
||||
|
||||
func TestUnpack(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
udpMsg.UnPack([]byte(`{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`))
|
||||
assert.Equal(content, string(udpMsg.Content))
|
||||
assert.Equal(src, udpMsg.Src.String())
|
||||
assert.Equal(dst, udpMsg.Dst.String())
|
||||
}
|
Reference in New Issue
Block a user