diff --git a/utils/ips_test.go b/utils/ips_test.go index 5f94ede3e..7ac94c60e 100644 --- a/utils/ips_test.go +++ b/utils/ips_test.go @@ -877,3 +877,284 @@ func TestIPPoolFieldAsString(t *testing.T) { }) } } + +func TestIPPoolFieldAsInterface(t *testing.T) { + pool := &IPPool{ + ID: "FIRST_POOL", + FilterIDs: []string{"flt1", "flt2"}, + Type: "*ipv4", + Range: "192.168.122.1/24", + Strategy: "*ascending", + Message: "message", + Weights: DynamicWeights{ + &DynamicWeight{FilterIDs: nil, Weight: 15}, + }, + Blockers: DynamicBlockers{ + &DynamicBlocker{FilterIDs: nil, Blocker: false}, + }, + } + + tests := []struct { + name string + fldPath []string + want any + wantErr bool + }{ + { + name: "ID field", + fldPath: []string{"ID"}, + want: "FIRST_POOL", + wantErr: false, + }, + { + name: "Type field", + fldPath: []string{"Type"}, + want: "*ipv4", + wantErr: false, + }, + { + name: "Range field", + fldPath: []string{"Range"}, + want: "192.168.122.1/24", + wantErr: false, + }, + { + name: "Strategy field", + fldPath: []string{"Strategy"}, + want: "*ascending", + wantErr: false, + }, + { + name: "Message field", + fldPath: []string{"Message"}, + want: "message", + wantErr: false, + }, + { + name: "FilterIDs full", + fldPath: []string{"FilterIDs"}, + want: []string{"flt1", "flt2"}, + wantErr: false, + }, + { + name: "FilterIDs index out of range", + fldPath: []string{"FilterIDs:5"}, + want: nil, + wantErr: true, + }, + { + name: "Weights field", + fldPath: []string{"Weights"}, + want: pool.Weights, + wantErr: false, + }, + { + name: "Blockers field", + fldPath: []string{"Blockers"}, + want: pool.Blockers, + wantErr: false, + }, + { + name: "Invalid field", + fldPath: []string{"Unknown"}, + want: nil, + wantErr: true, + }, + { + name: "Too deep path", + fldPath: []string{"ID", "extra"}, + want: nil, + wantErr: true, + }, + { + name: "Empty path (whole object)", + fldPath: []string{}, + want: pool, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := pool.FieldAsInterface(tt.fldPath) + if (err != nil) != tt.wantErr { + t.Errorf("FieldAsInterface() error = %v, wantErr %v", err, tt.wantErr) + return + } + switch expected := tt.want.(type) { + case string: + gotStr, ok := got.(string) + if !ok || gotStr != expected { + t.Errorf("FieldAsInterface() = %v, want %v", got, expected) + } + case []string: + gotSlice, ok := got.([]string) + if !ok || !reflect.DeepEqual(gotSlice, expected) { + t.Errorf("FieldAsInterface() = %v, want %v", got, expected) + } + default: + if !reflect.DeepEqual(got, expected) { + t.Errorf("FieldAsInterface() = %v, want %v", got, expected) + } + } + }) + } +} + +func TestIPPoolSet(t *testing.T) { + + pool := &IPPool{} + + tests := []struct { + name string + path []string + val any + wantErr bool + check func(t *testing.T, p *IPPool) + }{ + { + name: "Set ID field", + path: []string{"ID"}, + val: "FIRST_POOL", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if p.ID != "FIRST_POOL" { + t.Errorf("ID = %v, want %v", p.ID, "FIRST_POOL") + } + }, + }, + { + name: "Set FilterIDs field", + path: []string{"FilterIDs"}, + val: []string{"flt1", "flt2"}, + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if len(p.FilterIDs) != 2 || p.FilterIDs[0] != "flt1" || p.FilterIDs[1] != "flt2" { + t.Errorf("FilterIDs = %v, want [flt1 flt2]", p.FilterIDs) + } + }, + }, + { + name: "Set Type field", + path: []string{"Type"}, + val: "*ipv4", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if p.Type != "*ipv4" { + t.Errorf("Type = %v, want %v", p.Type, "*ipv4") + } + }, + }, + { + name: "Set Range field", + path: []string{"Range"}, + val: "192.168.122.1/24", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if p.Range != "192.168.122.1/24" { + t.Errorf("Range = %v, want %v", p.Range, "192.168.122.1/24") + } + }, + }, + { + name: "Set Strategy field", + path: []string{"Strategy"}, + val: "*ascending", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if p.Strategy != "*ascending" { + t.Errorf("Strategy = %v, want %v", p.Strategy, "*ascending") + } + }, + }, + { + name: "Set Message field", + path: []string{"Message"}, + val: "Some message", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if p.Message != "Some message" { + t.Errorf("Message = %v, want %v", p.Message, "Some message") + } + }, + }, + { + name: "Set Weights with valid string", + path: []string{"Weights"}, + val: "flt1&flt2;15;flt3;20", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if len(p.Weights) != 2 { + t.Errorf("Weights count = %d, want 2", len(p.Weights)) + return + } + if p.Weights[0].Weight != 15 { + t.Errorf("Weights[0].Weight = %v, want 15", p.Weights[0].Weight) + } + if p.Weights[1].Weight != 20 { + t.Errorf("Weights[1].Weight = %v, want 20", p.Weights[1].Weight) + } + if len(p.Weights[0].FilterIDs) != 2 || p.Weights[0].FilterIDs[0] != "flt1" { + t.Errorf("Weights[0].FilterIDs = %v, want [flt1 flt2]", p.Weights[0].FilterIDs) + } + }, + }, + { + name: "Set Blockers with valid string", + path: []string{"Blockers"}, + val: "flt1&flt2;false;flt3;true", + wantErr: false, + check: func(t *testing.T, p *IPPool) { + if len(p.Blockers) != 2 { + t.Errorf("Blockers count = %d, want 2", len(p.Blockers)) + return + } + if p.Blockers[0].Blocker != false { + t.Errorf("Blockers[0].Blocker = %v, want false", p.Blockers[0].Blocker) + } + if p.Blockers[1].Blocker != true { + t.Errorf("Blockers[1].Blocker = %v, want true", p.Blockers[1].Blocker) + } + if len(p.Blockers[0].FilterIDs) != 2 || p.Blockers[0].FilterIDs[0] != "flt1" { + t.Errorf("Blockers[0].FilterIDs = %v, want [flt1 flt2]", p.Blockers[0].FilterIDs) + } + }, + }, + { + name: "Set with wrong path length", + path: []string{"ID", "extra"}, + val: "bad", + wantErr: true, + }, + { + name: "Set unknown field", + path: []string{"Unknown"}, + val: "value", + wantErr: true, + }, + { + name: "Set Weights with invalid string", + path: []string{"Weights"}, + val: "flt1;badweight", + wantErr: true, + }, + { + name: "Set Blockers with invalid string", + path: []string{"Blockers"}, + val: "flt1;notabool", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := pool.Set(tt.path, tt.val, false) + if (err != nil) != tt.wantErr { + t.Fatalf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.check != nil && !tt.wantErr { + tt.check(t, pool) + } + }) + } +}