diff --git a/README.md b/README.md index 9b49e7f..0cbabfd 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,12 @@ With nftables: ```sh cat > /etc/sparoid.ini << EOF -bind = 0.0.0.0 +bind = :: port = 8484 key = $SPAROID_KEY hmac-key = $SPAROID_HMAC_KEY nftables-cmd = add element inet filter sparoid { %s } +nftablesv6-cmd = add element inet filter sparoid6 { %s } EOF cat > /etc/nftables.conf << EOF @@ -46,7 +47,8 @@ flush ruleset table inet filter { chain prerouting { type filter hook prerouting priority -300 - udp dport 8484 meter rate-limit-sparoid { ip saddr limit rate over 1/second burst 1 packets } counter drop + udp dport 8484 meter rate-limit-sparoid { ip saddr limit rate over 1/second burst 8 packets } counter drop + udp dport 8484 meter rate-limit-sparoid6 { ip6 saddr limit rate over 1/second burst 8 packets } counter drop udp dport 8484 notrack } @@ -60,6 +62,7 @@ table inet filter { udp dport 8484 accept ip saddr @jumphosts tcp dport ssh accept ip saddr @sparoid tcp dport ssh accept + ip6 saddr @sparoid6 tcp dport ssh accept } set sparoid { @@ -68,6 +71,12 @@ table inet filter { timeout 5s } + set sparoid6 { + type ipv6_addr + flags timeout, interval + timeout 5s + } + set jumphosts { type ipv4_addr elements = { 10.10.10.10 } diff --git a/spec/message_spec.cr b/spec/message_spec.cr new file mode 100644 index 0000000..8d49dec --- /dev/null +++ b/spec/message_spec.cr @@ -0,0 +1,152 @@ +require "./spec_helper" + +describe Sparoid::Message do + describe ".from_ip" do + it "creates message from IPv4 string" do + msg = Sparoid::Message.from_ip("192.168.1.100") + msg.family.should eq Socket::Family::INET + msg.ip_string.should eq "192.168.1.100" + msg.ip.size.should eq 4 + end + + it "creates message from IPv6 string" do + msg = Sparoid::Message.from_ip("2001:0db8:85a3::8a2e:0370:7334") + msg.family.should eq Socket::Family::INET6 + msg.ip_string.should eq "2001:0db8:85a3:0000:0000:8a2e:0370:7334" + msg.ip.size.should eq 16 + end + + it "strips IPv4-mapped IPv6 to plain IPv4" do + msg = Sparoid::Message.from_ip("::ffff:192.168.1.1") + msg.family.should eq Socket::Family::INET + msg.ip_string.should eq "192.168.1.1" + msg.ip.size.should eq 4 + end + + it "raises on invalid string" do + expect_raises(Exception, "Invalid IP address: not-an-ip") do + Sparoid::Message.from_ip("not-an-ip") + end + end + end + + describe "#ip_string" do + it "formats localhost" do + msg = Sparoid::Message.from_ip("127.0.0.1") + msg.ip_string.should eq "127.0.0.1" + end + + it "formats 0.0.0.0" do + msg = Sparoid::Message.from_ip("0.0.0.0") + msg.ip_string.should eq "0.0.0.0" + end + + it "formats 255.255.255.255" do + msg = Sparoid::Message.from_ip("255.255.255.255") + msg.ip_string.should eq "255.255.255.255" + end + + it "formats ::1 (loopback)" do + msg = Sparoid::Message.from_ip("::1") + msg.ip_string.should eq "0000:0000:0000:0000:0000:0000:0000:0001" + end + + it "formats :: (all zeros)" do + msg = Sparoid::Message.from_ip("::") + msg.ip_string.should eq "0000:0000:0000:0000:0000:0000:0000:0000" + end + + it "formats 2001:db8::" do + msg = Sparoid::Message.from_ip("2001:db8::") + msg.ip_string.should eq "2001:0db8:0000:0000:0000:0000:0000:0000" + end + + it "formats fe80::1 (link-local)" do + msg = Sparoid::Message.from_ip("fe80::1") + msg.ip_string.should eq "fe80:0000:0000:0000:0000:0000:0000:0001" + end + + it "formats ff02::1 (multicast)" do + msg = Sparoid::Message.from_ip("ff02::1") + msg.ip_string.should eq "ff02:0000:0000:0000:0000:0000:0000:0001" + end + end + + describe "serialization round-trip" do + it "serializes and deserializes IPv4" do + original = Sparoid::Message.from_ip("10.20.30.40") + slice = original.to_slice(IO::ByteFormat::NetworkEndian) + slice.size.should eq 32 + + io = IO::Memory.new(slice) + parsed = Sparoid::Message.from_io(io, IO::ByteFormat::NetworkEndian) + parsed.family.should eq Socket::Family::INET + parsed.ip_string.should eq "10.20.30.40" + parsed.ts.should eq original.ts + parsed.nounce.should eq original.nounce + end + + it "serializes and deserializes IPv6" do + original = Sparoid::Message.from_ip("2001:db8::1") + slice = original.to_slice(IO::ByteFormat::NetworkEndian) + slice.size.should eq 44 + + io = IO::Memory.new(slice) + parsed = Sparoid::Message.from_io(io, IO::ByteFormat::NetworkEndian) + parsed.family.should eq Socket::Family::INET6 + parsed.ip_string.should eq "2001:0db8:0000:0000:0000:0000:0000:0001" + parsed.ts.should eq original.ts + parsed.nounce.should eq original.nounce + end + + it "round-trips IPv4-mapped IPv6 as plain IPv4" do + original = Sparoid::Message.from_ip("::ffff:10.20.30.40") + slice = original.to_slice(IO::ByteFormat::NetworkEndian) + slice.size.should eq 32 + + io = IO::Memory.new(slice) + parsed = Sparoid::Message.from_io(io, IO::ByteFormat::NetworkEndian) + parsed.family.should eq Socket::Family::INET + parsed.ip_string.should eq "10.20.30.40" + parsed.ts.should eq original.ts + parsed.nounce.should eq original.nounce + end + + it "preserves timestamp and nonce" do + original = Sparoid::Message.from_ip("1.2.3.4") + slice = original.to_slice(IO::ByteFormat::NetworkEndian) + io = IO::Memory.new(slice) + parsed = Sparoid::Message.from_io(io, IO::ByteFormat::NetworkEndian) + parsed.ts.should eq original.ts + parsed.nounce.should eq original.nounce + parsed.ip.should eq original.ip + end + end + + describe ".from_io" do + it "raises on unsupported version" do + slice = Bytes.new(32) + IO::ByteFormat::NetworkEndian.encode(99_i32, slice[0, 4]) + io = IO::Memory.new(slice) + expect_raises(Exception, "Unsupported message version: 99") do + Sparoid::Message.from_io(io, IO::ByteFormat::NetworkEndian) + end + end + end + + describe "timestamp and nonce" do + it "generates unique nonces" do + msg1 = Sparoid::Message.from_ip("1.2.3.4") + msg2 = Sparoid::Message.from_ip("1.2.3.4") + msg1.nounce.should_not eq msg2.nounce + end + + it "generates timestamps close to current time" do + before = Time.utc.to_unix_ms + msg = Sparoid::Message.from_ip("1.2.3.4") + after = Time.utc.to_unix_ms + msg.ts.should be >= before + msg.ts.should be <= after + end + end +end diff --git a/spec/sparoid_spec.cr b/spec/sparoid_spec.cr index 603f788..9030ddf 100644 --- a/spec/sparoid_spec.cr +++ b/spec/sparoid_spec.cr @@ -1,4 +1,5 @@ require "./spec_helper" +require "socket" KEYS = Array(String).new(2) { Random::Secure.hex(32) } HMAC_KEYS = Array(String).new(2) { Random::Secure.hex(32) } @@ -7,7 +8,7 @@ ADDRESS = Socket::IPAddress.new("127.0.0.1", 8484) describe Sparoid::Server do it "works" do last_ip = nil - cb = ->(ip : String) { last_ip = ip } + cb = ->(ip : String, _family : Socket::Family) { last_ip = ip } s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) s.bind spawn s.listen @@ -21,7 +22,7 @@ describe Sparoid::Server do end it "fails invalid packet lengths" do - cb = ->(ip : String) { ip.should be_nil } + cb = ->(ip : String, _family : Socket::Family) { ip.should be_nil } s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) s.bind spawn s.listen @@ -36,7 +37,7 @@ describe Sparoid::Server do end it "fails invalid key" do - cb = ->(ip : String) { ip.should be_nil } + cb = ->(ip : String, _family : Socket::Family) { ip.should be_nil } s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) s.bind spawn s.listen @@ -49,7 +50,7 @@ describe Sparoid::Server do end it "fails invalid hmac key" do - cb = ->(ip : String) { ip.should be_nil } + cb = ->(ip : String, _family : Socket::Family) { ip.should be_nil } s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) s.bind spawn s.listen @@ -63,7 +64,7 @@ describe Sparoid::Server do it "client can cache IP" do accepted = 0 - cb = ->(_ip : String) { accepted += 1 } + cb = ->(_ip : String, _family : Socket::Family) { accepted += 1 } s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) s.bind spawn s.listen @@ -79,7 +80,7 @@ describe Sparoid::Server do it "works with two keys" do accepted = 0 - cb = ->(_ip : String) { accepted += 1 } + cb = ->(_ip : String, _family : Socket::Family) { accepted += 1 } s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) s.bind spawn s.listen @@ -95,12 +96,12 @@ describe Sparoid::Server do it "client can send another IP" do last_ip = nil - cb = ->(ip : String) { last_ip = ip } + cb = ->(ip : String, _family : Socket::Family) { last_ip = ip } address = Socket::IPAddress.new("0.0.0.0", ADDRESS.port) s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, address) s.bind spawn s.listen - Sparoid::Client.send(KEYS.first, HMAC_KEYS.first, "0.0.0.0", address.port, StaticArray[1u8, 1u8, 1u8, 1u8]) + Sparoid::Client.send(KEYS.first, HMAC_KEYS.first, "0.0.0.0", address.port, "1.1.1.1") Fiber.yield s.@seen_nounces.size.should eq 1 last_ip.should eq "1.1.1.1" @@ -108,9 +109,27 @@ describe Sparoid::Server do s.try &.close end + it "can parse manually constructed messages" do + last_ip = nil + cb = ->(ip : String, _family : Socket::Family) { last_ip = ip } + s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, ADDRESS) + s.bind + spawn s.listen + msg = Sparoid::Message.from_ip("127.0.0.1") + data = Sparoid::Client.generate_package(KEYS.first, HMAC_KEYS.first, msg) + socket = UDPSocket.new + socket.send data, to: ADDRESS + socket.close + Fiber.yield + s.@seen_nounces.size.should eq 1 + last_ip.should eq "127.0.0.1" + ensure + s.try &.close + end + it "can accept IPv4 connections on ::" do last_ip = nil - cb = ->(ip : String) { last_ip = ip } + cb = ->(ip : String, _family : Socket::Family) { last_ip = ip } address = Socket::IPAddress.new("::", ADDRESS.port) s = Sparoid::Server.new(KEYS, HMAC_KEYS, cb, address) s.bind @@ -122,17 +141,4 @@ describe Sparoid::Server do ensure s.try &.close end - - it "raises on unsupported message version" do - ts = Time.utc.to_unix_ms - nounce = StaticArray(UInt8, 16).new(0_u8) - Random::Secure.random_bytes(nounce.to_slice) - msg = Sparoid::Message.new(2, ts, nounce, StaticArray[127u8, 0u8, 0u8, 1u8]) - - msg.to_slice(IO::ByteFormat::NetworkEndian).tap do |slice| - expect_raises(Exception, "Unsupported message version: 2") do - Sparoid::Message.from_io(IO::Memory.new(slice), IO::ByteFormat::NetworkEndian) - end - end - end end diff --git a/src/client.cr b/src/client.cr index 544b0b8..284b86d 100644 --- a/src/client.cr +++ b/src/client.cr @@ -6,6 +6,8 @@ require "fdpass" require "./message" require "./public_ip" require "ini" +require "./ipv6" +require "wait_group" module Sparoid class Client @@ -27,63 +29,59 @@ module Sparoid self.new(key, hmac_key) end - def initialize(@key : String, @hmac_key : String, @ip = PublicIP.by_http) + def initialize(@key : String, @hmac_key : String) end def send(host : String, port : Int32) - self.class.send(@key, @hmac_key, host, port, @ip) + self.class.send(@key, @hmac_key, host, port) end - def self.send(key : String, hmac_key : String, host : String, port : Int32, ip = PublicIP.by_http) : Array(String) - ip = StaticArray[127u8, 0u8, 0u8, 1u8] if {"localhost", "127.0.0.1"}.includes? host - package = generate_package(key, hmac_key, ip) - udp_send(host, port, package).tap do + def self.send(key : String, hmac_key : String, host : String, port : Int32, public_ip : String? = nil) : Array(String) + udp_send(host, port, key, hmac_key, public_ip).tap do sleep 20.milliseconds # sleep a short while to allow the receiver to parse and execute the packet end end - def self.generate_package(key, hmac_key, ip) : Bytes + def self.generate_package(key, hmac_key, message : Message) : Bytes key = key.hexbytes hmac_key = hmac_key.hexbytes raise ArgumentError.new("Key must be 32 bytes hex encoded") if key.bytesize != 32 raise ArgumentError.new("HMAC key must be 32 bytes hex encoded") if hmac_key.bytesize != 32 - - msg = Message.new(ip) - encrypt(key, hmac_key, msg.to_slice(IO::ByteFormat::NetworkEndian)) + encrypt(key, hmac_key, message.to_slice(IO::ByteFormat::NetworkEndian)) end def self.fdpass(ips, port) : NoReturn - ch = Channel(Nil).new + wg = WaitGroup.new ips.each do |ip| - spawn do - socket = TCPSocket.new - socket.connect(Socket::IPAddress.new(ip, port), timeout: 10) + wg.spawn do + ipaddr = Socket::IPAddress.new(ip, port) + socket = TCPSocket.new ipaddr.family + socket.connect(ipaddr, timeout: 10) FDPass.send_fd(1, socket.fd) - # exit as soon as possible so no other fiber also succefully connects - exit 0 - rescue - ch.send(nil) + exit 0 # exit as soon as possible so no other fiber also succefully connects end end - ips.size.times { ch.receive } + wg.wait exit 1 # only if all connects fails end - # Send to all resolved IPs for the hostname - private def self.udp_send(host, port, data) : Array(String) - host_addresses = Socket::Addrinfo.udp(host, port, Socket::Family::INET) - socket = Socket.udp(Socket::Family::INET) - Socket.set_blocking(socket.fd, true) + # Send to all resolved IPs for the hostname, prioritizing IPv6 + private def self.udp_send(host, port, key : String, hmac_key : String, public_ip : String? = nil) : Array(String) + host_addresses = Socket::Addrinfo.udp(host, port) host_addresses.each do |addrinfo| + packages = generate_messages(addrinfo.ip_address, public_ip).map { |message| generate_package(key, hmac_key, message) } + socket = UDPSocket.new(addrinfo.family) begin - socket.send data, to: addrinfo.ip_address + packages.each do |data| + socket.send data, to: addrinfo.ip_address + end rescue ex STDERR << "Sparoid error sending " << ex.inspect << "\n" + ensure + socket.close end end host_addresses.map &.ip_address.address - ensure - socket.close if socket end private def self.encrypt(key, hmac_key, data) : Bytes @@ -109,5 +107,31 @@ module Sparoid STDOUT << "key = " << cipher.random_key.hexstring << "\n" STDOUT << "hmac-key = " << Random::Secure.hex(32) << "\n" end + + # Generate messages for all public IPs (IPv4 first, server may rate-limit). + private def self.generate_messages(host : Socket::IPAddress, public_ip : String? = nil) : Array(Message) + return [Message.from_ip(public_ip)] if public_ip + return local_ips(host).map { |ip| Message.from_ip(ip) } if host.loopback? || host.unspecified? + + [public_ipv4, public_ipv6].compact.map { |ip| Message.from_ip(ip) } + end + + # IPv4: from icanhazip + private def self.public_ipv4 : String? + PublicIP.ipv4 + end + + # IPv6: prefer OS-selected outgoing address, fall back to icanhazip + private def self.public_ipv6 : String? + IPv6.public_ipv6 || PublicIP.ipv6 + end + + private def self.local_ips(host : Socket::IPAddress) : Array(String) + if host.family == Socket::Family::INET + ["127.0.0.1"] + else + ["::1"] + end + end end end diff --git a/src/config.cr b/src/config.cr index 4ff1942..4820102 100644 --- a/src/config.cr +++ b/src/config.cr @@ -12,6 +12,7 @@ module Sparoid getter close_cmd = "" getter config_file = "/etc/sparoid.ini" getter nftables_cmd = "" + getter nftablesv6_cmd = "" def initialize parse_options @@ -47,13 +48,14 @@ module Sparoid # ignore sections, assume there's only the empty values.each do |k, v| case k - when "key" then @keys << v - when "hmac-key" then @hmac_keys << v - when "bind" then @host = v - when "port" then @port = v.to_i - when "open-cmd" then @open_cmd = v - when "close-cmd" then @close_cmd = v - when "nftables-cmd" then @nftables_cmd = v + when "key" then @keys << v + when "hmac-key" then @hmac_keys << v + when "bind" then @host = v + when "port" then @port = v.to_i + when "open-cmd" then @open_cmd = v + when "close-cmd" then @close_cmd = v + when "nftables-cmd" then @nftables_cmd = v + when "nftablesv6-cmd" then @nftablesv6_cmd = v end end end diff --git a/src/ipv6.cr b/src/ipv6.cr new file mode 100644 index 0000000..6ab1fc1 --- /dev/null +++ b/src/ipv6.cr @@ -0,0 +1,100 @@ +require "socket" + +class IPv6 + GOOGLE_DNS = Socket::IPAddress.new("2001:4860:4860::8888", 53) + + # Get the public IPv6 address by asking the OS which source address + # it would use to reach a well-known IPv6 destination. + # Returns nil if no global IPv6 address is available. + def self.public_ipv6 : String? + socket = UDPSocket.new(Socket::Family::INET6) + begin + socket.connect(GOOGLE_DNS) + addr = socket.local_address + return addr.address if global?(addr.address) + rescue + ensure + socket.close + end + nil + end + + # Check if an IPv6 address is globally reachable. + # Based on Rust std::net::Ipv6Addr::is_global (IETF RFC 4291, RFC 6890, etc.) + # ameba:disable Metrics/CyclomaticComplexity + def self.global?(ip : String) : Bool + s = Socket::IPAddress.parse_v6_fields?(ip) + return false unless s + return false if unspecified?(s) || loopback?(s) + return false if ipv4_mapped?(s) + return false if ipv4_ipv6_translation?(s) + return false if discard_only?(s) + return false if ietf_protocol_non_global?(s) + return false if sixto4?(s) + return false if documentation?(s) + return false if segment_routing?(s) + return false if unique_local?(s) + return false if link_local?(s) + true + end + + private def self.unspecified?(s) : Bool + s == StaticArray[0u16, 0, 0, 0, 0, 0, 0, 0] + end + + private def self.loopback?(s) : Bool + s == StaticArray[0u16, 0, 0, 0, 0, 0, 0, 1] + end + + # ::ffff:0:0/96 + private def self.ipv4_mapped?(s) : Bool + s[0] == 0 && s[1] == 0 && s[2] == 0 && s[3] == 0 && s[4] == 0 && s[5] == 0xffff + end + + # 64:ff9b:1::/48 + private def self.ipv4_ipv6_translation?(s) : Bool + s[0] == 0x64 && s[1] == 0xff9b && s[2] == 1 + end + + # 100::/64 + private def self.discard_only?(s) : Bool + s[0] == 0x100 && s[1] == 0 && s[2] == 0 && s[3] == 0 + end + + # 2001::/23 minus globally reachable sub-ranges + # ameba:disable Metrics/CyclomaticComplexity + private def self.ietf_protocol_non_global?(s) : Bool + return false unless s[0] == 0x2001 && s[1] < 0x200 + # PCP/TURN Anycast (2001:1::1, 2001:1::2) + return false if s[1] == 1 && s[2] == 0 && s[3] == 0 && s[4] == 0 && s[5] == 0 && s[6] == 0 && (s[7] == 1 || s[7] == 2) + return false if s[1] == 3 # AMT (2001:3::/32) + return false if s[1] == 4 && s[2] == 0x112 # AS112-v6 (2001:4:112::/48) + return false if s[1] >= 0x20 && s[1] <= 0x3f # ORCHIDv2 / Drone DETs + true + end + + # 2002::/16 + private def self.sixto4?(s) : Bool + s[0] == 0x2002 + end + + # 2001:db8::/32, 3fff:0000::/20 + private def self.documentation?(s) : Bool + (s[0] == 0x2001 && s[1] == 0xdb8) || (s[0] == 0x3fff && s[1] <= 0x0fff) + end + + # 5f00::/16 + private def self.segment_routing?(s) : Bool + s[0] == 0x5f00 + end + + # fc00::/7 + private def self.unique_local?(s) : Bool + s[0] & 0xfe00 == 0xfc00 + end + + # fe80::/10 + private def self.link_local?(s) : Bool + s[0] & 0xffc0 == 0xfe80 + end +end diff --git a/src/message.cr b/src/message.cr index 52c17ca..6665402 100644 --- a/src/message.cr +++ b/src/message.cr @@ -1,46 +1,90 @@ require "random/secure" +require "socket" module Sparoid struct Message - getter version : Int32, ts : Int64, nounce : StaticArray(UInt8, 16), ip : StaticArray(UInt8, 4) + VERSION = 1_i32 - def initialize(@version, @ts, @nounce, @ip) + getter ts : Int64, nounce : StaticArray(UInt8, 16), ip : Bytes + + def initialize(@ts, @nounce, @ip : Bytes) end - def initialize(@ip) - @version = 1 + def initialize(@ip : Bytes) @ts = Time.utc.to_unix_ms @nounce = uninitialized UInt8[16] Random::Secure.random_bytes(@nounce.to_slice) end - def to_io(io, format) - io.write_bytes @version, format - io.write_bytes @ts, format - io.write @nounce - io.write @ip + def family : Socket::Family + @ip.size == 4 ? Socket::Family::INET : Socket::Family::INET6 + end + + def ip_string : String + if @ip.size == 4 + String.build(15) do |str| + 4.times do |i| + str << '.' unless i == 0 + str << @ip[i] + end + end + else + String.build(39) do |str| + 8.times do |i| + str << ':' unless i == 0 + str << '0' if @ip[i * 2] < 0x10 + @ip[i * 2].to_s(str, 16) + str << '0' if @ip[i * 2 + 1] < 0x10 + @ip[i * 2 + 1].to_s(str, 16) + end + end + end end def to_slice(format : IO::ByteFormat) : Bytes - slice = Bytes.new(32) # version (4) + timestamp (8) + nounce (16) + ip (4) - format.encode(@version, slice[0, 4]) + size = 28 + @ip.size # version (4) + timestamp (8) + nounce (16) + ip (4 or 16) + slice = Bytes.new(size) + format.encode(VERSION, slice[0, 4]) format.encode(@ts, slice[4, 8]) - @nounce.to_slice.copy_to slice[12, @nounce.size] - @ip.to_slice.copy_to slice[28, @ip.size] + @nounce.to_slice.copy_to(slice[12, 16]) + @ip.copy_to(slice[28, @ip.size]) slice end - def self.from_io(io, format) + def self.from_io(io : IO, format : IO::ByteFormat) : Message version = Int32.from_io(io, format) - if version != 1 - raise "Unsupported message version: #{version}" - end + raise "Unsupported message version: #{version}" unless version == VERSION ts = Int64.from_io(io, format) nounce = uninitialized UInt8[16] io.read_fully(nounce.to_slice) - ip = uninitialized UInt8[4] - io.read_fully(ip.to_slice) - self.new(version, ts, nounce, ip) + buf = Bytes.new(16) + count = io.read(buf) + raise "Invalid IP: expected 4 or 16 bytes, got #{count}" unless count == 4 || count == 16 + Message.new(ts, nounce, strip_mapped_ipv4(buf[0, count])) + end + + def self.from_ip(ip : String) : Message + if fields = Socket::IPAddress.parse_v4_fields?(ip) + Message.new(Bytes[fields[0], fields[1], fields[2], fields[3]]) + elsif fields = Socket::IPAddress.parse_v6_fields?(ip) + ip_bytes = Bytes.new(16) + fields.each_with_index do |segment, i| + IO::ByteFormat::NetworkEndian.encode(segment, ip_bytes[i * 2, 2]) + end + Message.new(strip_mapped_ipv4(ip_bytes)) + else + raise "Invalid IP address: #{ip}" + end + end + + # Convert ::ffff:x.x.x.x (IPv4-mapped IPv6) to plain 4-byte IPv4 + private def self.strip_mapped_ipv4(ip : Bytes) : Bytes + if ip.size == 16 && + ip[0, 12] == Bytes[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff] + ip[12, 4].dup + else + ip.dup + end end end end diff --git a/src/public_ip.cr b/src/public_ip.cr index 808158a..14434cb 100644 --- a/src/public_ip.cr +++ b/src/public_ip.cr @@ -4,56 +4,38 @@ require "http/client" module Sparoid class PublicIP - # https://code.blogs.iiidefix.net/posts/get-public-ip-using-dns/ - def self.by_dns : StaticArray(UInt8, 4) - with_cache do - socket = UDPSocket.new - socket.connect("208.67.222.222", 53) # resolver1.opendns.com - header = DNS::Header.new(op_code: DNS::OpCode::Query, recursion_desired: false) - message = DNS::Message.new(header: header) - message.questions << DNS::Question.new(name: DNS::Name.new("myip.opendns.com"), query_type: DNS::RecordType::A) - message.to_socket socket - response = DNS::Message.from_socket socket - if answer = response.answers.first? - data = answer.data - case data - when DNS::IPv4Address - ip = data.to_slice - StaticArray(UInt8, 4).new do |i| - ip[i] - end - else raise "Unexpected response type from DNS request: #{data.inspect}" - end - else - raise "No A response from myip.opendns.com" - end - ensure - socket.try &.close - end + URLS = { + "http://ipv6.icanhazip.com", + "http://ipv4.icanhazip.com", + } + + def self.ipv4 : String? + by_http.find { |ip| !ip.includes?(':') } end - # ifconfig.co/ip is another option - def self.by_http : StaticArray(UInt8, 4) - with_cache do - resp = HTTP::Client.get("http://checkip.amazonaws.com") - raise "Could not retrive public ip" unless resp.status_code == 200 - str_to_arr resp.body - end + def self.ipv6 : String? + by_http.find(&.includes?(':')) end - private def self.str_to_arr(str : String) : StaticArray(UInt8, 4) - ip = StaticArray(UInt8, 4).new(0_u8) - i = 0 - str.split(".") do |part| - ip[i] = part.to_u8 - i += 1 + # icanhazip.com is from Cloudflare + # returns stripped IP addresses as strings, one per URL in URLS + def self.by_http : Array(String) + with_cache do + ips = URLS.compact_map do |url| + resp = HTTP::Client.get(url) + next unless resp.status_code == 200 + resp.body.chomp + rescue + nil + end + raise "No valid response from icanhazip.com" if ips.empty? + ips end - ip end CACHE_PATH = ENV.fetch("SPAROID_CACHE_PATH", "/tmp/.sparoid_public_ip") - private def self.with_cache(&blk : -> StaticArray(UInt8, 4)) : StaticArray(UInt8, 4) + private def self.with_cache(&blk : -> Array(String)) : Array(String) if up_to_date_cache? read_cache else @@ -68,23 +50,27 @@ module Sparoid false end - private def self.read_cache : StaticArray(UInt8, 4) + private def self.read_cache : Array(String) File.open(CACHE_PATH, "r") do |file| file.flock_shared - str_to_arr(file.gets_to_end) + Array(String).new.tap do |ips| + while line = file.gets + ips << line + end + end end end - private def self.write_cache(& : -> StaticArray(UInt8, 4)) : StaticArray(UInt8, 4) + private def self.write_cache(& : -> Array(String)) : Array(String) File.open(CACHE_PATH, "a", 0o0644) do |file| file.flock_exclusive - ip = yield - file.truncate - ip.each_with_index do |e, i| - file.print '.' unless i.zero? - file.print e + ips = yield + file.truncate(0) + file.rewind + ips.each do |ip| + file.puts ip end - ip + ips end end end diff --git a/src/server-cli.cr b/src/server-cli.cr index b284639..0689cfd 100644 --- a/src/server-cli.cr +++ b/src/server-cli.cr @@ -8,16 +8,30 @@ begin puts "Listening: #{c.host}:#{c.port}" puts "Keys: #{c.keys.size}" puts "HMAC keys: #{c.hmac_keys.size}" - if c.nftables_cmd.bytesize > 0 + if c.nftables_cmd.bytesize > 0 || c.nftablesv6_cmd.bytesize > 0 puts "nftables command: #{c.nftables_cmd}" + puts "nftablesv6 command: #{c.nftablesv6_cmd}" nft = Nftables.new - on_accept = ->(ip_str : String) { - nft.run_cmd sprintf(c.nftables_cmd, ip_str) + on_accept = ->(ip_str : String, family : Socket::Family) : Nil { + case family + when Socket::Family::INET6 + if c.nftablesv6_cmd.bytesize > 0 + nft.run_cmd sprintf(c.nftablesv6_cmd, ip_str) + else + puts "WARNING: no nftablesv6-cmd configured, skipping #{ip_str}" + end + when Socket::Family::INET + if c.nftables_cmd.bytesize > 0 + nft.run_cmd sprintf(c.nftables_cmd, ip_str) + else + puts "WARNING: no nftables-cmd configured, skipping #{ip_str}" + end + end } else puts "Open command: #{c.open_cmd}" puts "Close command: #{c.close_cmd}" - on_accept = ->(ip_str : String) : Nil { + on_accept = ->(ip_str : String, _family : Socket::Family) : Nil { spawn do system sprintf(c.open_cmd, ip_str) unless c.close_cmd.empty? diff --git a/src/server.cr b/src/server.cr index c6c846f..0d77d7e 100644 --- a/src/server.cr +++ b/src/server.cr @@ -8,7 +8,7 @@ module Sparoid @keys : Array(Bytes) @hmac_keys : Array(Bytes) - def initialize(keys : Enumerable(String), hmac_keys : Enumerable(String), @on_accept : Proc(String, Nil), @address : Socket::IPAddress) + def initialize(keys : Enumerable(String), hmac_keys : Enumerable(String), @on_accept : Proc(String, Socket::Family, Nil), @address : Socket::IPAddress) @keys = keys.map &.hexbytes @hmac_keys = hmac_keys.map &.hexbytes raise ArgumentError.new("Key must be 32 bytes hex encoded") if @keys.any? { |k| k.bytesize != 32 } @@ -43,8 +43,8 @@ module Sparoid msg = Message.from_io(plain, IO::ByteFormat::NetworkEndian) verify_ts(msg.ts) verify_nounce(msg.nounce) - ip_str = ip_to_s(msg.ip) - @on_accept.call(ip_str) + ip_str = msg.ip_string + @on_accept.call(ip_str, msg.family) ip_str end @@ -57,7 +57,7 @@ module Sparoid private def verify_nounce(nounce) if @seen_nounces.includes? nounce - raise "reply-attack, nounce seen before" + raise "replay-attack, nounce seen before" end @seen_nounces.shift if @seen_nounces.size >= MAX_NOUNCES @seen_nounces.push nounce @@ -71,15 +71,6 @@ module Sparoid end end - private def ip_to_s(ip) - String.build(15) do |str| - ip.each_with_index do |part, i| - str << '.' unless i == 0 - str << part - end - end - end - private def verify_packet(data : Bytes) : Bytes packet_mac = data[0, 32] data += 32