/*
Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package hana

import (
	"context"
	"os"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"google.golang.org/protobuf/testing/protocmp"
	"github.com/GoogleCloudPlatform/sapagent/internal/sapcontrolclient"
	"github.com/GoogleCloudPlatform/sapagent/internal/sapcontrolclient/test/sapcontrolclienttest"
	"github.com/GoogleCloudPlatform/sapagent/internal/system"
	"github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/commandlineexecutor"
	"github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/log"

	mrpb "google.golang.org/genproto/googleapis/monitoring/v3"
	cpb "github.com/GoogleCloudPlatform/sapagent/protos/configuration"
	iipb "github.com/GoogleCloudPlatform/sapagent/protos/instanceinfo"
	sapb "github.com/GoogleCloudPlatform/sapagent/protos/sapapp"
)

// TestMain sets up test logging once for the whole package, runs every test,
// and exits the process with the status reported by m.Run.
//
// The parameter is named m (not t) because it is a *testing.M, not a
// *testing.T; using t here would suggest per-test helpers are available.
func TestMain(m *testing.M) {
	log.SetupLoggingForTest()
	os.Exit(m.Run())
}

var (
	// defaultSAPInstance is a HANA primary instance with database
	// user/password credentials and two HA members.
	defaultSAPInstance = &sapb.SAPInstance{
		Sapsid:         "TST",
		InstanceNumber: "00",
		ServiceName:    "test-service",
		Type:           sapb.InstanceType_HANA,
		Site:           sapb.InstanceSite_HANA_PRIMARY,
		HanaHaMembers:  []string{"test-instance-1", "test-instance-2"},
		HanaDbUser:     "test-user",
		HanaDbPassword: "test-pass",
	}

	// defaultConfig turns process-metric collection off and carries fake
	// cloud properties for a single test VM.
	defaultConfig = &cpb.Configuration{
		CollectionConfiguration: &cpb.CollectionConfiguration{
			CollectProcessMetrics:   false,
			ProcessMetricsFrequency: 5,
		},
		CloudProperties: &iipb.CloudProperties{
			ProjectId:        "test-project",
			InstanceId:       "test-instance",
			Zone:             "test-zone",
			InstanceName:     "test-instance",
			Image:            "test-image",
			NumericProjectId: "123456",
		},
	}

	// defaultReplicationConfig is a ReplicationConfig fake whose first result
	// is 1 and which reports no error.
	defaultReplicationConfig = func(ctx context.Context, user, sid, instID string, sapSystemInterface system.SapSystemDiscoveryInterface) (int, int64, *sapb.HANAReplicaSite, error) {
		return 1, 1, nil, nil
	}

	// replicationConfigForStandaloneInstance is a ReplicationConfig fake whose
	// first result is 0 and which reports no error; this file uses it for the
	// standalone (replication-disabled) scenario.
	replicationConfigForStandaloneInstance = func(ctx context.Context, user, sid, instID string, sapSystemInterface system.SapSystemDiscoveryInterface) (int, int64, *sapb.HANAReplicaSite, error) {
		return 0, 1, nil, nil
	}

	// replicationConfigForRefreshFailure is a ReplicationConfig fake that
	// always reports a non-nil error.
	replicationConfigForRefreshFailure = func(ctx context.Context, user, sid, instID string, sapSystemInterface system.SapSystemDiscoveryInterface) (int, int64, *sapb.HANAReplicaSite, error) {
		return 0, 1, nil, cmpopts.AnyError
	}

	// defaultInstanceProperties pairs defaultConfig with defaultSAPInstance
	// and leaves every other field at its zero value.
	defaultInstanceProperties = &InstanceProperties{
		Config:      defaultConfig,
		SAPInstance: defaultSAPInstance,
	}

	// defaultAPIInstanceProperties has the same contents as
	// defaultInstanceProperties; in this file the SAPControl web-method tests
	// use it.
	defaultAPIInstanceProperties = &InstanceProperties{
		Config:      defaultConfig,
		SAPInstance: defaultSAPInstance,
	}

	// instancePropertiesWithReplication wires in defaultReplicationConfig.
	instancePropertiesWithReplication = &InstanceProperties{
		Config:            defaultConfig,
		SAPInstance:       defaultSAPInstance,
		ReplicationConfig: defaultReplicationConfig,
	}

	// instancePropertiesWithReplicationDisabled wires in
	// replicationConfigForStandaloneInstance.
	instancePropertiesWithReplicationDisabled = &InstanceProperties{
		Config:            defaultConfig,
		SAPInstance:       defaultSAPInstance,
		ReplicationConfig: replicationConfigForStandaloneInstance,
	}

	// instancePropertiesWithReplicationRefreshFailure wires in
	// replicationConfigForRefreshFailure.
	instancePropertiesWithReplicationRefreshFailure = &InstanceProperties{
		Config:            defaultConfig,
		SAPInstance:       defaultSAPInstance,
		ReplicationConfig: replicationConfigForRefreshFailure,
	}

	// defaultSapControlOutput is sample process-list text listing seven HANA
	// processes, all GREEN. NOTE(review): not referenced in this file;
	// presumably used by other tests in the package.
	defaultSapControlOutput = `OK
		0 name: hdbdaemon
		0 dispstatus: GREEN
		0 pid: 111
		1 name: hdbcompileserver
		1 dispstatus: GREEN
		1 pid: 222
		2 name: hdbindexserver
		2 dispstatus: GREEN
		2 pid: 333
		3 name: hdbnameserver
		3 dispstatus: GREEN
		3 pid: 444
		4 name: hdbpreprocessor
		4 dispstatus: GREEN
		4 pid: 555
		5 name: hdbwebdispatcher
		5 dispstatus: GREEN
		5 pid: 666
		6 name: hdbxsengine
		6 dispstatus: GREEN
		6 pid: 777`
)

// fakeRunner is a test double that stores canned command results, which its
// RunWithEnv method returns unchanged. NOTE(review): it is not referenced in
// this file; presumably other tests in the package use it.
type fakeRunner struct {
	stdOut   string // reported standard output
	stdErr   string // reported standard error
	exitCode int    // reported exit status
	err      error  // reported execution error
}

// RunWithEnv returns the fake's stored stdout, stderr, exit code and error,
// in that order, without executing anything.
func (f *fakeRunner) RunWithEnv() (string, string, int, error) {
	out, errOut := f.stdOut, f.stdErr
	return out, errOut, f.exitCode, f.err
}

func TestCollectHANAServiceMetrics(t *testing.T) {
	tests := []struct {
		name               string
		fakeClient         sapcontrolclienttest.Fake
		wantMetricCount    int
		wantErr            error
		instanceProperties *InstanceProperties
	}{
		{
			name: "SuccessWebmethod",
			fakeClient: sapcontrolclienttest.Fake{
				Processes: []sapcontrolclient.OSProcess{
					{Name: "hdbdaemon", Dispstatus: "SAPControl-GREEN", Pid: 9609},
					{Name: "hdbcompileserver", Dispstatus: "SAPControl-GREEN", Pid: 9972},
					{Name: "hdbindexserver", Dispstatus: "SAPControl-GREEN", Pid: 10013},
					{Name: "hdbnameserver", Dispstatus: "SAPControl-GREEN", Pid: 9642},
					{Name: "hdbpreprocessor", Dispstatus: "SAPControl-GREEN", Pid: 9975},
					{Name: "hdbwebdispatcher", Dispstatus: "SAPControl-GREEN", Pid: 666},
					{Name: "hdbxsengine", Dispstatus: "SAPControl-GREEN", Pid: 777},
				},
			},
			wantMetricCount:    7,
			instanceProperties: defaultAPIInstanceProperties,
		},
		{
			name:               "FailureWebmethodGetProcessList",
			fakeClient:         sapcontrolclienttest.Fake{ErrGetProcessList: cmpopts.AnyError},
			wantMetricCount:    0,
			wantErr:            cmpopts.AnyError,
			instanceProperties: defaultAPIInstanceProperties,
		},
		{
			name: "EmptyProcessList",
			fakeClient: sapcontrolclienttest.Fake{
				Processes: []sapcontrolclient.OSProcess{},
			},
			wantMetricCount:    0,
			instanceProperties: defaultAPIInstanceProperties,
		},
		{
			name: "MetricsSkipped",
			fakeClient: sapcontrolclienttest.Fake{
				Processes: []sapcontrolclient.OSProcess{
					{Name: "hdbdaemon", Dispstatus: "SAPControl-GREEN", Pid: 9609},
					{Name: "hdbcompileserver", Dispstatus: "SAPControl-GREEN", Pid: 9972},
					{Name: "hdbindexserver", Dispstatus: "SAPControl-GREEN", Pid: 10013},
					{Name: "hdbnameserver", Dispstatus: "SAPControl-GREEN", Pid: 9642},
					{Name: "hdbpreprocessor", Dispstatus: "SAPControl-GREEN", Pid: 9975},
					{Name: "hdbwebdispatcher", Dispstatus: "SAPControl-GREEN", Pid: 666},
				},
			},
			instanceProperties: &InstanceProperties{
				SAPInstance: defaultSAPInstance,
				Config: &cpb.Configuration{
					CollectionConfiguration: &cpb.CollectionConfiguration{
						ProcessMetricsToSkip: []string{servicePath},
					},
				},
				SkippedMetrics: map[string]bool{servicePath: true},
			},
			wantMetricCount: 0,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			metrics, err := collectHANAServiceMetrics(context.Background(), test.instanceProperties, test.fakeClient)
			if len(metrics) != test.wantMetricCount {
				t.Errorf("collectHANAServiceMetrics() metric count mismatch, got: %v want: %v.", len(metrics), test.wantMetricCount)
			}
			if !cmp.Equal(err, test.wantErr, cmpopts.EquateErrors()) {
				t.Errorf("collectHANAServiceMetrics() gotErr: %v wantErr: %v.", err, test.wantErr)
			}
		})
	}
}

// TestRunHANAQuery checks the queryState and error that runHANAQuery derives
// from canned command results: the exit code, the overall/server timings
// parsed from the output, and the failure paths for unparsable or
// overflowing timings and authentication errors.
func TestRunHANAQuery(t *testing.T) {
	successOutput := `| D |
	| - |
	| X |
	1 row selected (overall time 1187 usec; server time 509 usec)`

	// returning builds an Execute that ignores its inputs and yields res.
	returning := func(res commandlineexecutor.Result) commandlineexecutor.Execute {
		return func(context.Context, commandlineexecutor.Params) commandlineexecutor.Result {
			return res
		}
	}

	cases := []struct {
		name           string
		ip             *InstanceProperties
		fakeExec       commandlineexecutor.Execute
		wantQueryState queryState
		wantErr        error
	}{
		{
			name:           "Success",
			ip:             defaultInstanceProperties,
			fakeExec:       returning(commandlineexecutor.Result{StdOut: successOutput}),
			wantQueryState: queryState{state: 0, overallTime: 1187, serverTime: 509},
		},
		{
			name:           "NonZeroState",
			ip:             defaultInstanceProperties,
			fakeExec:       returning(commandlineexecutor.Result{StdOut: successOutput, ExitCode: 100}),
			wantQueryState: queryState{state: 100, overallTime: 1187, serverTime: 509},
		},
		{
			// A non-nil Result.Error with exit code 0 is still expected to parse
			// the timings and return no error.
			name: "ExitCodeZeroWithError",
			ip:   defaultInstanceProperties,
			fakeExec: returning(commandlineexecutor.Result{
				StdOut: "(overall time 10 usec; server time 10 usec)",
				StdErr: "Not Found.",
				Error:  cmpopts.AnyError,
			}),
			wantQueryState: queryState{state: 0, overallTime: 10, serverTime: 10},
		},
		{
			name:     "ParseOverallTimeFailure",
			ip:       defaultInstanceProperties,
			fakeExec: returning(commandlineexecutor.Result{StdOut: "(overall time invalid-int; server time 509 usec)."}),
			wantErr:  cmpopts.AnyError,
		},
		{
			name: "ParseServerTimeFailure",
			ip:   defaultInstanceProperties,
			fakeExec: returning(commandlineexecutor.Result{
				StdOut:   "(overall time 1187 usec; server time invalid-int)",
				ExitCode: 128,
			}),
			wantErr: cmpopts.AnyError,
		},
		{
			name:     "IntegerOverflow",
			ip:       defaultInstanceProperties,
			fakeExec: returning(commandlineexecutor.Result{StdOut: "(overall time 100000000000000000000 usec; server time 10 usec)"}),
			wantErr:  cmpopts.AnyError,
		},
		{
			name: "AuthenticationFailed",
			ip:   defaultInstanceProperties,
			fakeExec: returning(commandlineexecutor.Result{
				StdErr:   "* 10: authentication failed SQLSTATE: 28000\n",
				ExitCode: 3,
			}),
			wantErr: cmpopts.AnyError,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotQueryState, gotErr := runHANAQuery(context.Background(), tc.ip, tc.fakeExec)

			if !cmp.Equal(gotErr, tc.wantErr, cmpopts.EquateErrors()) {
				t.Errorf("runHANAQuery(), gotErr: %v wantErr: %v.", gotErr, tc.wantErr)
			}

			diff := cmp.Diff(tc.wantQueryState, gotQueryState, cmp.AllowUnexported(queryState{}))
			if diff != "" {
				t.Errorf("runHANAQuery(), diff (-want +got):\n%s", diff)
			}
		})
	}
}

// TestCollectHANAQueryMetrics checks that a query finishing with exit code 0
// and parsable timing output yields three metrics.
func TestCollectHANAQueryMetrics(t *testing.T) {
	const wantLen = 3
	succeedingExec := func(context.Context, commandlineexecutor.Params) commandlineexecutor.Result {
		return commandlineexecutor.Result{
			StdOut: "1 row selected (overall time 1187 usec; server time 509 usec)",
		}
	}

	// Only the number of returned metrics is asserted; the error is not.
	metrics, _ := collectHANAQueryMetrics(context.Background(), defaultInstanceProperties, succeedingExec)
	if got := len(metrics); got != wantLen {
		t.Errorf("collectHANAQueryMetrics(), got: %d want: 3.", got)
	}
}

// TestCollectHANAQueryMetricsWithMaxFailCounts calls collectHANAQueryMetrics
// three times against a query that always fails authentication, reusing the
// same InstanceProperties so any state stored on it (HANAQueryFailCount starts
// at 0) carries over between calls. The first two calls must return a query
// state metric with value 1; the third must return nil.
func TestCollectHANAQueryMetricsWithMaxFailCounts(t *testing.T) {
	fakeExec := func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
		return commandlineexecutor.Result{
			StdErr:   "* 10: authentication failed SQLSTATE: 28000\n",
			ExitCode: 3,
		}
	}
	ip := &InstanceProperties{
		Config:             defaultConfig,
		SAPInstance:        defaultSAPInstance,
		HANAQueryFailCount: 0,
	}

	for i := 0; i < 3; i++ {
		// The returned error is intentionally not asserted; this test only
		// checks which metrics come back on each successive failed attempt.
		got, _ := collectHANAQueryMetrics(context.Background(), ip, fakeExec)
		switch i {
		case 0, 1:
			// Guard the indexing below so an unexpectedly empty result is
			// reported as a failure rather than an index-out-of-range panic.
			if len(got) == 0 || len(got[0].GetPoints()) == 0 {
				t.Fatalf("collectHANAQueryMetrics() call %d, got: %v want: a query state metric with at least one point.", i, got)
			}
			// Reuse the metric's own end time so the comparison is not
			// sensitive to when the metric was created.
			ts := got[0].GetPoints()[0].GetInterval().GetEndTime()
			want := []*mrpb.TimeSeries{createMetrics(ip, queryStatePath, nil, ts, int64(1))}
			if diff := cmp.Diff(want[0], got[0], protocmp.Transform()); diff != "" {
				t.Errorf("collectHANAQueryMetrics() call %d, diff (-want +got):\n%s", i, diff)
			}
		default:
			if got != nil {
				t.Errorf("collectHANAQueryMetrics(), got: %v want: nil.", got)
			}
		}
	}
}

func TestCollectHANALogUtilisationKb(t *testing.T) {
	tests := []struct {
		name                  string
		wantMetricCount       int
		wantErr               error
		instanceProperties    *InstanceProperties
		exec                  commandlineexecutor.Execute
		checkLabels           bool
		expectedDiskSizeValue string
	}{
		{
			name: "SuccessfulCollection",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "1234",
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value\n 888 999",
				}
			},
			wantMetricCount:       1,
			instanceProperties:    instancePropertiesWithReplication,
			checkLabels:           true,
			expectedDiskSizeValue: "999",
		},
		{
			name: "FailedCollectionOnStandaloneInstance",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "1234",
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value/n 888 999",
				}
			},
			wantMetricCount:    0,
			instanceProperties: instancePropertiesWithReplicationDisabled,
		},
		{
			name: "FailedCollectionWhenUnableToFetchReplicationSite",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "1234",
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value/n 888 999",
				}
			},
			wantMetricCount:    0,
			wantErr:            cmpopts.AnyError,
			instanceProperties: instancePropertiesWithReplicationRefreshFailure,
		},
		{
			name: "FailedCollectionWhenDuCommandFails",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						Error: cmpopts.AnyError,
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value/n 888 999",
				}
			},
			wantMetricCount:    0,
			wantErr:            cmpopts.AnyError,
			instanceProperties: instancePropertiesWithReplication,
		},
		{
			name: "FailedCollectionWhenDuCommandReturnsInvalidOutput",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "",
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value/n 888 999",
				}
			},
			wantMetricCount:    0,
			wantErr:            nil,
			instanceProperties: instancePropertiesWithReplication,
		},
		{
			name: "SuccessfulCollectionWithoutLabelWhenCommandFails",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "1234",
					}
				}
				return commandlineexecutor.Result{
					Error: cmpopts.AnyError,
				}
			},
			wantMetricCount:       1,
			instanceProperties:    instancePropertiesWithReplication,
			checkLabels:           true,
			expectedDiskSizeValue: "",
		},
		{
			name: "SuccessfulCollectionWithoutLabelWhenCommandReturnsInvalidOutput1",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "1234",
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value",
				}
			},
			wantMetricCount:       1,
			instanceProperties:    instancePropertiesWithReplication,
			checkLabels:           true,
			expectedDiskSizeValue: "",
		},
		{
			name: "SuccessfulCollectionWithoutLabelWhenCommandReturnsInvalidOutput2",
			exec: func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result {
				if params.Executable == "du" {
					return commandlineexecutor.Result{
						StdOut: "1234",
					}
				}
				return commandlineexecutor.Result{
					StdOut: "test_value\n 888",
				}
			},
			wantMetricCount:       1,
			instanceProperties:    instancePropertiesWithReplication,
			checkLabels:           true,
			expectedDiskSizeValue: "",
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			metrics, err := collectHANALogUtilisationKb(context.Background(), test.instanceProperties, test.exec)
			if len(metrics) != test.wantMetricCount {
				t.Errorf("collectHANALogUtilisationKb() metric count mismatch, got: %v want: %v.", len(metrics), test.wantMetricCount)
			}
			if !cmp.Equal(err, test.wantErr, cmpopts.EquateErrors()) {
				t.Errorf("collectHANALogUtilisationKb() gotErr: %v wantErr: %v.", err, test.wantErr)
			}
			if test.checkLabels && len(metrics) > 0 {
				labels := metrics[0].GetMetric().GetLabels()
				if labels["hana_log_disk_size_kb"] != test.expectedDiskSizeValue {
					t.Errorf("collectHANALogUtilisationKb() hana_log_disk_size_kb label mismatch, got: %v want: %v.", labels["hana_log_disk_size_kb"], test.expectedDiskSizeValue)
				}
			}
		})
	}
}

func TestCollect(t *testing.T) {
	tests := []struct {
		name       string
		properties *InstanceProperties
		wantCount  int
		wantErr    error
	}{
		{
			name: "MetricCountTest",
			properties: &InstanceProperties{
				Config: defaultConfig,
				SAPInstance: &sapb.SAPInstance{
					Sapsid:         "TST",
					InstanceNumber: "00",
					HanaDbUser:     "test-user",
					HanaDbPassword: "test-pass",
				},
				SkippedMetrics: map[string]bool{
					servicePath:          true,
					logUtilisationKbPath: true,
				},
			},
			wantCount: 1, // Without HANA setup in unit test ENV, only query/state metric is generated.
			wantErr:   nil,
		},
		{
			name: "MetricCountTestUserstoreAuth",
			properties: &InstanceProperties{
				Config: defaultConfig,
				SAPInstance: &sapb.SAPInstance{
					Sapsid:          "TST",
					InstanceNumber:  "00",
					HdbuserstoreKey: "test-key",
				},
				SkippedMetrics: map[string]bool{
					servicePath:          true,
					logUtilisationKbPath: true,
				},
			},
			wantCount: 1, // Without HANA setup in unit test ENV, only query/state metric is generated.
			wantErr:   nil,
		},
		{
			name: "NoHANADBUserAndKey",
			properties: &InstanceProperties{
				Config: defaultConfig,
				SAPInstance: &sapb.SAPInstance{
					Sapsid:         "TST",
					InstanceNumber: "00",
				},
				SkippedMetrics: map[string]bool{
					servicePath:          true,
					logUtilisationKbPath: true,
				},
			},
			wantCount: 0, // Query state metric not generated without credentials.
		},
		{
			name: "NoHANADBUser",
			properties: &InstanceProperties{
				Config: defaultConfig,
				SAPInstance: &sapb.SAPInstance{
					Sapsid:         "TST",
					InstanceNumber: "00",
					HanaDbPassword: "test-pass",
				},
				SkippedMetrics: map[string]bool{
					servicePath:          true,
					logUtilisationKbPath: true,
				},
			},
			wantCount: 0, // Query state metric not generated without credentials.
		},
		{
			name: "NoHANADBPassword",
			properties: &InstanceProperties{
				Config: defaultConfig,
				SAPInstance: &sapb.SAPInstance{
					Sapsid:         "TST",
					InstanceNumber: "00",
					HanaDbUser:     "test-user",
				},
				SkippedMetrics: map[string]bool{
					servicePath:          true,
					logUtilisationKbPath: true,
				},
			},
			wantCount: 0, // Query state metric not generated without credentials.
		},
		{
			name: "HANASecondaryNode",
			properties: &InstanceProperties{
				Config: defaultConfig,
				SAPInstance: &sapb.SAPInstance{
					Sapsid:         "TST",
					InstanceNumber: "00",
					HanaDbUser:     "test-user",
					HanaDbPassword: "test-pass",
					Site:           sapb.InstanceSite_HANA_SECONDARY,
				},
				SkippedMetrics: map[string]bool{
					servicePath:          true,
					logUtilisationKbPath: true,
				},
			},
			wantCount: 0, // Query state metric not generated for HANA secondary.
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			got, gotErr := test.properties.Collect(context.Background())
			if len(got) != test.wantCount {
				t.Errorf("Collect(), got: %d want: %d.", len(got), test.wantCount)
			}
			if !cmp.Equal(gotErr, test.wantErr, cmpopts.EquateErrors()) {
				t.Errorf("Collect(), gotErr: %v wantErr: %v.", gotErr, test.wantErr)
			}
		})
	}
}
