diff --git a/backend/analytics_server/mhq/exapi/github.py b/backend/analytics_server/mhq/exapi/github.py index 50cd1929d..dac20b009 100644 --- a/backend/analytics_server/mhq/exapi/github.py +++ b/backend/analytics_server/mhq/exapi/github.py @@ -22,12 +22,18 @@ class GithubRateLimitExceeded(Exception): class GithubApiService: - def __init__(self, access_token: str): + def __init__(self, access_token: str, domain: Optional[str]): self._token = access_token - self._g = Github(self._token, per_page=PAGE_SIZE) - self.base_url = "https://api.github.com" + self.base_url = self._get_api_url(domain) + self._g = Github(self._token, base_url=self.base_url, per_page=PAGE_SIZE) self.headers = {"Authorization": f"Bearer {self._token}"} + def _get_api_url(self, domain: str) -> str: + if not domain: + return "https://api.github.com" + else: + return f"{domain}/api/v3" + @contextlib.contextmanager def temp_config(self, per_page: int = 30): self._g.per_page = per_page diff --git a/backend/analytics_server/mhq/service/code/sync/etl_github_handler.py b/backend/analytics_server/mhq/service/code/sync/etl_github_handler.py index 8f5c1acf8..c208a6cbc 100644 --- a/backend/analytics_server/mhq/service/code/sync/etl_github_handler.py +++ b/backend/analytics_server/mhq/service/code/sync/etl_github_handler.py @@ -3,6 +3,7 @@ from typing import List, Dict, Optional, Tuple, Set import pytz +from mhq.utils.github import get_custom_github_domain from github.PaginatedList import PaginatedList as GithubPaginatedList from github.PullRequest import PullRequest as GithubPullRequest from github.PullRequestReview import PullRequestReview as GithubPullRequestReview @@ -371,7 +372,7 @@ def _get_access_token(): return GithubETLHandler( org_id, - GithubApiService(_get_access_token()), + GithubApiService(_get_access_token(), get_custom_github_domain(org_id)), CodeRepoService(), CodeETLAnalyticsService(), get_revert_prs_github_sync_handler(), diff --git a/backend/analytics_server/mhq/service/external_integrations_service.py b/backend/analytics_server/mhq/service/external_integrations_service.py index 9a4bcac5e..8d480be6e 100644 --- a/backend/analytics_server/mhq/service/external_integrations_service.py +++ b/backend/analytics_server/mhq/service/external_integrations_service.py @@ -26,7 +26,7 @@ def __init__( self.custom_domain = custom_domain def get_github_organizations(self): - github_api_service = GithubApiService(self.access_token) + github_api_service = GithubApiService(self.access_token, self.custom_domain) try: orgs: [GithubOrganization] = github_api_service.get_org_list() except GithubException as e: @@ -34,21 +34,21 @@ def get_github_organizations(self): return orgs def get_github_org_repos(self, org_login: str, page_size: int, page: int): - github_api_service = GithubApiService(self.access_token) + github_api_service = GithubApiService(self.access_token, self.custom_domain) try: return github_api_service.get_repos_raw(org_login, page_size, page) except GithubException as e: raise e def get_github_personal_repos(self, page_size: int, page: int): - github_api_service = GithubApiService(self.access_token) + github_api_service = GithubApiService(self.access_token, self.custom_domain) try: return github_api_service.get_user_repos_raw(page_size, page) except GithubException as e: raise e def get_repo_workflows(self, gh_org_name: str, gh_org_repo_name: str): - github_api_service = GithubApiService(self.access_token) + github_api_service = GithubApiService(self.access_token, self.custom_domain) try: workflows = github_api_service.get_repo_workflows( gh_org_name, gh_org_repo_name diff --git a/backend/analytics_server/mhq/service/workflows/sync/etl_github_actions_handler.py b/backend/analytics_server/mhq/service/workflows/sync/etl_github_actions_handler.py index 92d718232..22024dca1 100644 --- a/backend/analytics_server/mhq/service/workflows/sync/etl_github_actions_handler.py +++ b/backend/analytics_server/mhq/service/workflows/sync/etl_github_actions_handler.py @@ -4,6 +4,7 @@ import pytz +from mhq.utils.github import get_custom_github_domain from mhq.exapi.github import GithubApiService from mhq.service.workflows.sync.etl_provider_handler import WorkflowProviderETLHandler from mhq.store.models import UserIdentityProvider @@ -181,5 +182,7 @@ def _get_access_token(): return access_token return GithubActionsETLHandler( - org_id, GithubApiService(_get_access_token()), WorkflowRepoService() + org_id, + GithubApiService(_get_access_token(), get_custom_github_domain(org_id)), + WorkflowRepoService(), ) diff --git a/backend/analytics_server/mhq/utils/github.py b/backend/analytics_server/mhq/utils/github.py index c42fabf72..de20c68e7 100644 --- a/backend/analytics_server/mhq/utils/github.py +++ b/backend/analytics_server/mhq/utils/github.py @@ -1,9 +1,12 @@ from queue import Queue from threading import Thread +from typing import Optional from github import Organization from mhq.utils.log import LOG +from mhq.store.repos.core import CoreRepoService +from mhq.store.models import UserIdentityProvider def github_org_data_multi_thread_worker(orgs: [Organization]) -> dict: @@ -48,3 +51,25 @@ def run(self): for worker in workers: r.update(worker.results) return r + + +def get_custom_github_domain(org_id: str) -> Optional[str]: + DEFAULT_DOMAIN = "https://api.github.com" + core_repo_service = CoreRepoService() + integrations = core_repo_service.get_org_integrations_for_names( + org_id, [UserIdentityProvider.GITHUB.value] + ) + + github_domain = ( + integrations[0].provider_meta.get("custom_domain") + if integrations[0].provider_meta + else None + ) + + if not github_domain: + LOG.warn( + f"Custom domain not found for intergration for org {org_id} and provider {UserIdentityProvider.GITHUB.value}" + ) + return DEFAULT_DOMAIN + + return github_domain diff --git a/backend/analytics_server/tests/exapi/__init__.py b/backend/analytics_server/tests/exapi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/analytics_server/tests/exapi/test_github.py b/backend/analytics_server/tests/exapi/test_github.py new file mode 100644 index 000000000..ee51156cc --- /dev/null +++ b/backend/analytics_server/tests/exapi/test_github.py @@ -0,0 +1,40 @@ +import unittest +from unittest.mock import patch + +from mhq.exapi.github import GithubApiService, PAGE_SIZE + + +class DummyGithub: + def __init__(self, token, base_url=None, per_page=None): + self.token = token + self.base_url = base_url + self.per_page = per_page + + +class TestGithubApiService(unittest.TestCase): + + @patch("mhq.exapi.github.Github", new=DummyGithub) + def test_default_domain_sets_standard_api_url(self): + token = "deadpool" + service = GithubApiService(access_token=token, domain=None) + self.assertEqual(service.base_url, "https://api.github.com") + self.assertIsInstance(service._g, DummyGithub) + self.assertEqual(service._g.token, token) + self.assertEqual(service._g.base_url, "https://api.github.com") + self.assertEqual(service._g.per_page, PAGE_SIZE) + + @patch("mhq.exapi.github.Github", new=DummyGithub) + def test_empty_string_domain_uses_default_url(self): + token = "deadpool" + service = GithubApiService(access_token=token, domain="") + self.assertEqual(service.base_url, "https://api.github.com") + self.assertEqual(service._g.base_url, "https://api.github.com") + + @patch("mhq.exapi.github.Github", new=DummyGithub) + def test_custom_domain_appends_api_v3(self): + token = "deadpool" + custom_domain = "https://github.sujai.com" + service = GithubApiService(access_token=token, domain=custom_domain) + expected = f"{custom_domain}/api/v3" + self.assertEqual(service.base_url, expected) + self.assertEqual(service._g.base_url, expected) diff --git a/web-server/jest.config.js b/web-server/jest.config.js index d157fc27b..d8a8cca04 100644 --- a/web-server/jest.config.js +++ b/web-server/jest.config.js @@ -1,6 +1,7 @@ module.exports = { preset: 'ts-jest/presets/js-with-babel', // Use the TypeScript preset with Babel testEnvironment: 'jsdom', // Use jsdom as the test environment (for browser-like behavior) + setupFiles: ['/jest.setup.js'], moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], testMatch: [ '**/__tests__/**/*.test.(ts|tsx|js|jsx)', diff --git a/web-server/jest.setup.js b/web-server/jest.setup.js new file mode 100644 index 000000000..83f38a5de --- /dev/null +++ b/web-server/jest.setup.js @@ -0,0 +1,3 @@ +const { TextEncoder, TextDecoder } = require('util'); +global.TextEncoder = TextEncoder; +global.TextDecoder = TextDecoder; diff --git a/web-server/pages/api/internal/[org_id]/__tests__/github.test.ts b/web-server/pages/api/internal/[org_id]/__tests__/github.test.ts new file mode 100644 index 000000000..1cd277f62 --- /dev/null +++ b/web-server/pages/api/internal/[org_id]/__tests__/github.test.ts @@ -0,0 +1,84 @@ +jest.mock('@/utils/db', () => ({ + db: jest.fn(), +})); + +import { db } from '@/utils/db'; +import * as githubUtils from '../utils'; +import { DEFAULT_GH_URL } from '@/constants/urls'; + +describe('GitHub URL utilities', () => { + afterEach(() => { + jest.resetAllMocks(); + }); + + describe('getGitHubCustomDomain', () => { + it('returns custom_domain when present', async () => { + const mockMeta = [{ custom_domain: 'custom.sujai.com' }]; + (db as jest.Mock).mockReturnValue({ + where: jest.fn().mockReturnThis(), + then: jest.fn().mockResolvedValue(mockMeta) + }); + + const domain = await githubUtils.getGitHubCustomDomain(); + expect(domain).toBe('custom.sujai.com'); + }); + + it('returns null when no provider_meta found', async () => { + (db as jest.Mock).mockReturnValue({ + where: jest.fn().mockReturnThis(), + then: jest.fn().mockResolvedValue([]) + }); + + const domain = await githubUtils.getGitHubCustomDomain(); + expect(domain).toBeNull(); + }); + + it('returns null on db error and logs error', async () => { + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(); + (db as jest.Mock).mockImplementation(() => { + throw new Error('DB failure'); + }); + + const domain = await githubUtils.getGitHubCustomDomain(); + expect(domain).toBeNull(); + expect(consoleSpy).toHaveBeenCalledWith( + 'Error occured while getting custom domain from database:', + expect.any(Error) + ); + }); + }); + + describe('getGitHubRestApiUrl', () => { + it('uses default URL when no custom domain', async () => { + jest.spyOn(githubUtils, 'getGitHubCustomDomain').mockResolvedValue(null); + const url = await githubUtils.getGitHubRestApiUrl('path/to/repo'); + expect(url).toBe(`${DEFAULT_GH_URL}/path/to/repo`); + }); + + it('uses custom domain when provided', async () => { + jest.spyOn(githubUtils, 'getGitHubCustomDomain').mockResolvedValue('git.sujai.com'); + const url = await githubUtils.getGitHubRestApiUrl('repos/owner/repo'); + expect(url).toBe('https://git.sujai.com/api/v3/repos/owner/repo'); + }); + + it('normalizes multiple slashes in URL', async () => { + jest.spyOn(githubUtils, 'getGitHubCustomDomain').mockResolvedValue('git.sujai.com/'); + const url = await githubUtils.getGitHubRestApiUrl('/repos//owner//repo'); + expect(url).toBe('https://git.sujai.com/api/v3/repos/owner/repo'); + }); + }); + + describe('getGitHubGraphQLUrl', () => { + it('uses default GraphQL endpoint when no custom domain', async () => { + jest.spyOn(githubUtils, 'getGitHubCustomDomain').mockResolvedValue(null); + const url = await githubUtils.getGitHubGraphQLUrl(); + expect(url).toBe(`${DEFAULT_GH_URL}/graphql`); + }); + + it('uses custom domain for GraphQL endpoint', async () => { + jest.spyOn(githubUtils, 'getGitHubCustomDomain').mockResolvedValue('api.github.local'); + const url = await githubUtils.getGitHubGraphQLUrl(); + expect(url).toBe('https://api.github.local/api/graphql'); + }); + }); +}); diff --git a/web-server/pages/api/internal/[org_id]/utils.ts b/web-server/pages/api/internal/[org_id]/utils.ts index 3503c0116..1ff52c91b 100644 --- a/web-server/pages/api/internal/[org_id]/utils.ts +++ b/web-server/pages/api/internal/[org_id]/utils.ts @@ -5,8 +5,7 @@ import { Row } from '@/constants/db'; import { Integration } from '@/constants/integrations'; import { BaseRepo } from '@/types/resources'; import { db } from '@/utils/db'; - -const GITHUB_API_URL = 'https://api.github.com/graphql'; +import { DEFAULT_GH_URL } from '@/constants/urls'; type GithubRepo = { name: string; @@ -53,7 +52,7 @@ export const searchGithubRepos = async ( }; const searchRepoWithURL = async (searchString: string) => { - const apiUrl = `https://api.github.com/repos/${searchString}`; + const apiUrl = await getGitHubRestApiUrl(`repos/${searchString}`); const response = await axios.get(apiUrl); const repo = response.data; return [ @@ -104,7 +103,9 @@ export const searchGithubReposWithNames = async ( const queryString = `${searchString} in:name fork:true`; - const response = await fetch(GITHUB_API_URL, { + const githubApiUrl = await getGitHubGraphQLUrl(); + + const response = await fetch(githubApiUrl, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -304,3 +305,33 @@ const replaceURL = async (url: string): Promise => { return url; }; + +export const getGitHubCustomDomain = async (): Promise => { + try { + const provider_meta = await db('Integration') + .where('name', Integration.GITHUB) + .then((r: Row<'Integration'>[]) => r.map((item) => item.provider_meta)); + + return head(provider_meta || [])?.custom_domain || null; + } catch (error) { + console.error('Error occured while getting custom domain from database:', error); + return null; + } +}; + +const normalizeSlashes = (url: string) => + url.replace(/(? { + const customDomain = await getGitHubCustomDomain(); + const base = customDomain + ? `${customDomain}/api/v3` + : DEFAULT_GH_URL; + return normalizeSlashes(`${base}/${path}`); +}; + + +export const getGitHubGraphQLUrl = async (): Promise => { + const customDomain = await getGitHubCustomDomain(); + return customDomain ? `${customDomain}/api/graphql` : `${DEFAULT_GH_URL}/graphql`; +}; diff --git a/web-server/src/constants/urls.ts b/web-server/src/constants/urls.ts new file mode 100644 index 000000000..3d5de9624 --- /dev/null +++ b/web-server/src/constants/urls.ts @@ -0,0 +1 @@ +export const DEFAULT_GH_URL = 'https://api.github.com'; diff --git a/web-server/src/content/Dashboards/ConfigureGithubModalBody.tsx b/web-server/src/content/Dashboards/ConfigureGithubModalBody.tsx index 3ed2ab932..e1c36692c 100644 --- a/web-server/src/content/Dashboards/ConfigureGithubModalBody.tsx +++ b/web-server/src/content/Dashboards/ConfigureGithubModalBody.tsx @@ -17,6 +17,7 @@ import { linkProvider, getMissingPATScopes } from '@/utils/auth'; +import { checkDomainWithRegex } from '@/utils/domainCheck'; import { depFn } from '@/utils/fn'; export const ConfigureGithubModalBody: FC<{ @@ -25,10 +26,12 @@ export const ConfigureGithubModalBody: FC<{ const token = useEasyState(''); const { orgId } = useAuth(); const { enqueueSnackbar } = useSnackbar(); + const customDomain = useEasyState(''); const dispatch = useDispatch(); const isLoading = useBoolState(); const showError = useEasyState(''); + const showDomainError = useEasyState(''); const setError = useCallback( (error: string) => { @@ -37,62 +40,75 @@ export const ConfigureGithubModalBody: FC<{ }, [showError.set] ); + const setDomainError = useCallback( + (error: string) => { + depFn(showDomainError.set, error); + }, + [showDomainError.set] + ); const handleChange = (e: string) => { token.set(e); showError.set(''); }; + const handleDomainChange = (e: string) => { + customDomain.set(e); + showDomainError.set(''); + }; const handleSubmission = useCallback(async () => { if (!token.value) { setError('Please enter a valid token'); return; } - depFn(isLoading.true); - checkGitHubValidity(token.value) - .then(async (isValid) => { - if (!isValid) throw new Error('Invalid token'); - }) - .then(async () => { - try { - const res = await getMissingPATScopes(token.value); - if (res.length) { - throw new Error(`Token is missing scopes: ${res.join(', ')}`); - } - } catch (e) { - // @ts-ignore - throw new Error(e?.message, e); - } - }) - .then(async () => { - try { - return await linkProvider(token.value, orgId, Integration.GITHUB); - } catch (e: any) { - throw new Error( - `Failed to link Github${e?.message ? `: ${e?.message}` : ''}`, - e - ); - } - }) - .then(() => { - dispatch(fetchCurrentOrg()); - dispatch( - fetchTeams({ - org_id: orgId - }) - ); - enqueueSnackbar('Github linked successfully', { - variant: 'success', - autoHideDuration: 2000 - }); - onClose(); - }) - .catch((e) => { - setError(e.message); - console.error(`Error while linking token: ${e.message}`, e); - }) - .finally(isLoading.false); + if ( + customDomain.value && + !checkDomainWithRegex(customDomain.valueRef.current) + ) { + setDomainError('Please enter a valid domain'); + return; + } + + isLoading.true(); + try { + const isValid = await checkGitHubValidity( + token.value, + customDomain.valueRef.current + ); + if (!isValid) { + setError('Invalid token'); + return; + } + + const missingScopes = await getMissingPATScopes( + token.value, + customDomain.valueRef.current + ); + if (missingScopes.length > 0) { + setError(`Token is missing scopes: ${missingScopes.join(', ')}`); + return; + } + + await linkProvider(token.value, orgId, Integration.GITHUB, { + custom_domain: customDomain.valueRef.current + }); + + dispatch(fetchCurrentOrg()); + dispatch(fetchTeams({ org_id: orgId })); + enqueueSnackbar('Github linked successfully', { + variant: 'success', + autoHideDuration: 2000 + }); + onClose(); + } catch (e: any) { + setError(e.message || 'Unknown error'); + console.error(e); + } finally { + isLoading.false(); + } }, [ + token.value, + customDomain.value, dispatch, enqueueSnackbar, isLoading.false, @@ -100,9 +116,17 @@ export const ConfigureGithubModalBody: FC<{ onClose, orgId, setError, - token.value + setDomainError ]); + const isDomainInputFocus = useBoolState(false); + + const focusDomainInput = useCallback(() => { + if (!customDomain.value) + document.getElementById('gitlab-custom-domain')?.focus(); + else handleSubmission(); + }, [customDomain.value, handleSubmission]); + return ( @@ -115,6 +139,7 @@ export const ConfigureGithubModalBody: FC<{ e.stopPropagation(); e.nativeEvent.stopImmediatePropagation(); handleSubmission(); + focusDomainInput(); return; } }} @@ -150,6 +175,43 @@ export const ConfigureGithubModalBody: FC<{ + + + Custom domain + + { + if (e.key === 'Enter') { + e.preventDefault(); + e.stopPropagation(); + e.nativeEvent.stopImmediatePropagation(); + handleSubmission(); + return; + } + }} + error={!!showDomainError.value} + sx={{ width: '100%' }} + value={customDomain.value} + onChange={(e) => handleDomainChange(e.currentTarget.value)} + label={ + isDomainInputFocus.value || customDomain.value + ? 'Custom Domain' + : '(Optional)' + } + onFocus={isDomainInputFocus.true} + onBlur={isDomainInputFocus.false} + helperText={ + isDomainInputFocus.value || customDomain.value + ? 'Example: https://github.mycompany.com' + : '' + } + placeholder="https://github.mycompany.com" + /> + + + {showDomainError.value} + diff --git a/web-server/src/content/Dashboards/ConfigureGitlabModalBody.tsx b/web-server/src/content/Dashboards/ConfigureGitlabModalBody.tsx index 6aef9ff54..eef8a0942 100644 --- a/web-server/src/content/Dashboards/ConfigureGitlabModalBody.tsx +++ b/web-server/src/content/Dashboards/ConfigureGitlabModalBody.tsx @@ -17,6 +17,7 @@ import { checkGitLabValidity, getMissingGitLabScopes } from '@/utils/auth'; +import { checkDomainWithRegex } from '@/utils/domainCheck'; import { depFn } from '@/utils/fn'; export const ConfigureGitlabModalBody: FC<{ @@ -46,11 +47,6 @@ export const ConfigureGitlabModalBody: FC<{ [showDomainError.set] ); - const checkDomainWithRegex = (domain: string) => { - const regex = - /^(https?:\/\/)[a-zA-Z0-9]+([-.][a-zA-Z0-9]+)*\.[a-zA-Z]{2,}(:[0-9]{1,5})?(\/.*)?$/; - return regex.test(domain); - }; const handleTokenChange = (e: string) => { token.set(e); showScopeError.set(''); diff --git a/web-server/src/utils/__tests__/domainCheck.test.ts b/web-server/src/utils/__tests__/domainCheck.test.ts new file mode 100644 index 000000000..512269fe7 --- /dev/null +++ b/web-server/src/utils/__tests__/domainCheck.test.ts @@ -0,0 +1,40 @@ +import { checkDomainWithRegex } from '../domainCheck'; + +describe('checkDomainWithRegex', () => { + const validDomains = [ + 'http://example.com', + 'https://example.com', + 'https://sub.example.co.uk', + 'http://example.io:8080', + 'https://example.io:8080', + 'https://example.com/', + 'https://123domain.net', + 'http://my-domain.org' + ]; + + test.each(validDomains)('returns true for %s', (domain) => { + expect(checkDomainWithRegex(domain)).toBe(true); + }); + + const invalidDomains = [ + 'example.com', + 'ftp://example.com', + 'http:/example.com', + 'https//example.com', + 'https://-example.com', + 'https://example-.com', + 'https://example', + 'https://.com', + 'https://example:toolongtsadasds', + 'https://example.com:999999', + 'https://example .com', + 'https://example.com/ path', + '', + 'https://', + 'https:///' + ]; + + test.each(invalidDomains)('returns false for %s', (domain) => { + expect(checkDomainWithRegex(domain)).toBe(false); + }); +}); diff --git a/web-server/src/utils/auth.ts b/web-server/src/utils/auth.ts index 4cc2d93bb..22852b373 100644 --- a/web-server/src/utils/auth.ts +++ b/web-server/src/utils/auth.ts @@ -2,6 +2,7 @@ import axios from 'axios'; import { isNil, reject } from 'ramda'; import { Integration } from '@/constants/integrations'; +import { DEFAULT_GH_URL } from '@/constants/urls'; export const unlinkProvider = async (orgId: string, provider: Integration) => { return await axios.delete(`/api/resources/orgs/${orgId}/integration`, { @@ -28,10 +29,15 @@ export const linkProvider = async ( // GitHub functions export async function checkGitHubValidity( - good_stuff: string + good_stuff: string, + customDomain?: string ): Promise { try { - await axios.get('https://api.github.com/user', { + // if customDomain is provded, the host will be customDomain/api/v3 + // else it will be api.github.com(default) + const baseUrl = customDomain ? `${customDomain}/api/v3` : DEFAULT_GH_URL; + + await axios.get(`${baseUrl}/user`, { headers: { Authorization: `token ${good_stuff}` } @@ -43,9 +49,13 @@ export async function checkGitHubValidity( } const PAT_SCOPES = ['read:org', 'read:user', 'repo', 'workflow']; -export const getMissingPATScopes = async (pat: string) => { +export const getMissingPATScopes = async ( + pat: string, + customDomain?: string +) => { + const baseUrl = customDomain ? `${customDomain}/api/v3` : DEFAULT_GH_URL; try { - const response = await axios.get('https://api.github.com', { + const response = await axios.get(baseUrl, { headers: { Authorization: `token ${pat}` } diff --git a/web-server/src/utils/domainCheck.ts b/web-server/src/utils/domainCheck.ts new file mode 100644 index 000000000..bb92ebeec --- /dev/null +++ b/web-server/src/utils/domainCheck.ts @@ -0,0 +1,5 @@ +export const checkDomainWithRegex = (domain: string) => { + const regex = + /^(https?:\/\/)[A-Za-z0-9]+([-.][A-Za-z0-9]+)*\.[A-Za-z]{2,}(:[0-9]{1,5})?(\/\S*)?$/; + return regex.test(domain); +};